Skip to content

Commit

Permalink
[CONTRIB][CC] Enhance cc.cross_compiler (#4817)
Browse files Browse the repository at this point in the history
* [CONTRIB][CC] Enhance cc.cross_compiler

- Enhance cc.cross_compiler to take str argument.
- Remove cc.build_create_shared_func as it is dupilicated with cross_compiler
- Add examples to cc.cross_compiler

* address review comments
  • Loading branch information
tqchen authored Feb 6, 2020
1 parent 5ea4f0d commit 19d0d15
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
69 changes: 42 additions & 27 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,38 +87,23 @@ def get_target_triple():
create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)

def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd.

Parameters
----------
options : List[str]
The list of additional options string.
def cross_compiler(compile_func,
options=None,
output_format=None,
get_target_triple=None):
"""Create a cross compiler function by specializing compile_func with options.
compile_cmd : Optional[str]
The compiler command.
Returns
-------
create_shared_wrapper : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library or to autotvm.LocalBuilder.
"""
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper
This function can be used to construct compile functions that
can be passed to AutoTVM measure or export_library.
def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function.
Parameters
----------
compile_func : Callable[[str, str, Optional[str]], None]
compile_func : Union[str, Callable[[str, str, Optional[str]], None]]
Function that performs the actual compilation
base_options : Optional[List[str]]
options : Optional[List[str]]
List of additional optional string.
output_format : Optional[str]
Expand All @@ -131,14 +116,44 @@ def cross_compiler(compile_func, base_options=None, output_format="so", get_targ
-------
fcompile : Callable[[str, str, Optional[str]], None]
A compilation function that can be passed to export_library.
Examples
--------
.. code-block:: python
from tvm.contrib import cc, ndk
# export using arm gcc
mod = build_runtime_module()
mod.export_library(path_dso,
cc.cross_compiler("arm-linux-gnueabihf-gcc"))
# specialize ndk compilation options.
specialized_ndk = cc.cross_compiler(
ndk.create_shared,
["--sysroot=/path/to/sysroot", "-shared", "-fPIC", "-lm"])
mod.export_library(path_dso, specialized_ndk)
"""
if base_options is None:
base_options = []
base_options = [] if options is None else options
kwargs = {}

# handle case where compile_func is the name of the cc
if isinstance(compile_func, str):
kwargs = {"cc" : compile_func}
compile_func = create_shared


def _fcompile(outputs, objects, options=None):
all_options = base_options
if options is not None:
all_options += options
compile_func(outputs, objects, options=all_options)
compile_func(outputs, objects, options=all_options, **kwargs)

if not output_format and hasattr(compile_func, "output_format"):
output_format = compile_func.output_format
output_format = output_format if output_format else "so"

if not get_target_triple and hasattr(compile_func, "get_target_triple"):
get_target_triple = compile_func.get_target_triple

_fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_module_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def check_device(device):
raise ValueError("Unsupported platform")

path_dso = temp.relpath("dev_lib.so")
f.export_library(path_dso)
# test cross compiler function
f.export_library(path_dso, cc.cross_compiler("g++"))

f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
Expand All @@ -134,8 +135,8 @@ def check_stackvm(device):
name = "myadd_%s" % device
f = tvm.build(s, [A, B], device, "stackvm", name=name)
path_dso = temp.relpath("dev_lib.stackvm")
#f.export_library(path_dso)
#f1 = tvm.module.load(path_dso)
f.export_library(path_dso)
f1 = tvm.module.load(path_dso)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
f(a, b)
Expand Down

0 comments on commit 19d0d15

Please sign in to comment.