diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 0c836144577b..1d368347686b 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -87,38 +87,23 @@ def get_target_triple(): create_shared.get_target_triple = get_target_by_dump_machine( "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None) -def build_create_shared_func(options=None, compile_cmd="g++"): - """Build create_shared function with particular default options and compile_cmd. - Parameters - ---------- - options : List[str] - The list of additional options string. +def cross_compiler(compile_func, + options=None, + output_format=None, + get_target_triple=None): + """Create a cross compiler function by specializing compile_func with options. - compile_cmd : Optional[str] - The compiler command. - - Returns - ------- - create_shared_wrapper : Callable[[str, str, Optional[str]], None] - A compilation function that can be passed to export_library or to autotvm.LocalBuilder. - """ - def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd): - create_shared(output, objects, options, compile_cmd) - create_shared_wrapper.output_format = create_shared.output_format - create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd) - return create_shared_wrapper + This function can be used to construct compile functions that + can be passed to AutoTVM measure or export_library. -def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None): - """Create a cross compiler function. - Parameters ---------- - compile_func : Callable[[str, str, Optional[str]], None] + compile_func : Union[str, Callable[[str, str, Optional[str]], None]] Function that performs the actual compilation - base_options : Optional[List[str]] + options : Optional[List[str]] List of additional optional string. output_format : Optional[str] @@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ ------- fcompile : Callable[[str, str, Optional[str]], None] A compilation function that can be passed to export_library. + + Examples + -------- + .. code-block:: python + + from tvm.contrib import cc, ndk + # export using arm gcc + mod = build_runtime_module() + mod.export_library(path_dso, + cc.cross_compiler("arm-linux-gnueabihf-gcc")) + # specialize ndk compilation options. + specialized_ndk = cc.cross_compiler( + ndk.create_shared, + ["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"]) + mod.export_library(path_dso, specialized_ndk) """ - if base_options is None: - base_options = [] + base_options = [] if options is None else options + kwargs = {} + + # handle case where compile_func is the name of the cc + if isinstance(compile_func, str): + kwargs = {"cc" : compile_func} + compile_func = create_shared + + def _fcompile(outputs, objects, options=None): all_options = base_options if options is not None: all_options += options - compile_func(outputs, objects, options=all_options) + compile_func(outputs, objects, options=all_options, **kwargs) + + if not output_format and hasattr(compile_func, "output_format"): + output_format = compile_func.output_format + output_format = output_format if output_format else "so" + + if not get_target_triple and hasattr(compile_func, "get_target_triple"): + get_target_triple = compile_func.get_target_triple + _fcompile.output_format = output_format _fcompile.get_target_triple = get_target_triple return _fcompile diff --git a/tests/python/unittest/test_module_load.py b/tests/python/unittest/test_module_load.py index e8e43352987e..b1ef1c6fbb17 100644 --- a/tests/python/unittest/test_module_load.py +++ b/tests/python/unittest/test_module_load.py @@ -113,7 +113,8 @@ def check_device(device): raise ValueError("Unsupported platform") path_dso = temp.relpath("dev_lib.so") - f.export_library(path_dso) + # test cross compiler function + f.export_library(path_dso, cc.cross_compiler("g++")) f1 = tvm.module.load(path_dso) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) @@ -134,8 +135,8 @@ def check_stackvm(device): name = "myadd_%s" % device f = tvm.build(s, [A, B], device, "stackvm", name=name) path_dso = temp.relpath("dev_lib.stackvm") - #f.export_library(path_dso) - #f1 = tvm.module.load(path_dso) + f.export_library(path_dso) + f1 = tvm.module.load(path_dso) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) f(a, b)