Skip to content

Commit

Permalink
[ADDON] Allow piggy back nvcc compiler and code (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Feb 7, 2017
1 parent 8837798 commit 08505e3
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 24 deletions.
14 changes: 12 additions & 2 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ typedef enum {
kArrayHandle = 5U,
kTVMType = 6U,
kNodeHandle = 7U,
kStr = 8U,
kFuncHandle = 9U
kFuncHandle = 8U,
kStr = 9U,
kBytes = 10U
} TVMTypeCode;

/*!
Expand Down Expand Up @@ -86,6 +87,15 @@ typedef union {
TVMType v_type;
} TVMValue;

/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
*/
typedef struct {
const char* data;
size_t size;
} TVMByteArray;

/*!
* \brief The device type
*/
Expand Down
17 changes: 14 additions & 3 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ class PackedFunc {
* \return reference to the registered function.
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Whether the global function exist
* \param name The name of the function.
* \return Whetehr the global function exist.
*/
static bool GlobalExist(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
Expand Down Expand Up @@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
} else {
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
operator TVMType() const {
if (type_code_ == kStr) {
Expand Down Expand Up @@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr: {
case kStr:
case kBytes: {
SwitchToClass<std::string>(kStr, other);
break;
}
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/_ctypes/_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access, too-many-branches
"""Symbolic configuration API."""
from __future__ import absolute_import

Expand All @@ -9,7 +9,7 @@

from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
from ._node import NodeBase, SliceBase, convert_to_node
Expand Down Expand Up @@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
Expand Down
30 changes: 25 additions & 5 deletions python/tvm/_ctypes/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class TypeCode(object):
ARRAY_HANDLE = 5
TVM_TYPE = 6
NODE_HANDLE = 7
STR = 8
FUNC_HANDLE = 9
FUNC_HANDLE = 8
STR = 9
BYTES = 10

def _api_type(code):
"""create a type accepted by API"""
Expand Down Expand Up @@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]

class TVMByteArray(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]


TVMPackedCFunc = ctypes.CFUNCTYPE(
None,
Expand All @@ -110,20 +116,34 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle)
return handle

def _return_bytes(x):
"""return handle"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res


RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str)
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}


C_TO_PY_ARG_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str)
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}
1 change: 1 addition & 0 deletions python/tvm/addon/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Addon utilities to python"""
55 changes: 55 additions & 0 deletions python/tvm/addon/nvcc_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Util to compile with NVCC"""
import os
import sys
import tempfile
import subprocess

def compile_source(code, target="cubin"):
"""Compile cuda code with NVCC from env.
Parameters
----------
code : str
The cuda code.
target: str
The target format
Return
------
cubin : bytearray
The bytearray of the cubin
"""
temp_dir = tempfile.mkdtemp()
if target not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target must be in cubin, ptx, fatbin")
path_code = os.path.join(temp_dir, "my_kernel.cu")
path_target = os.path.join(temp_dir, "my_kernel.%s" % target)

with open(path_code, "w") as out_file:
out_file.write(code)

cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
cmd += [path_code]
args = ' '.join(cmd)

proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()

if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(out)
sys.stderr.flush()
cubin = None
else:
cubin = bytearray(open(path_target, "rb").read())
os.remove(path_code)
if os.path.exists(path_target):
os.remove(path_target)
os.rmdir(temp_dir)
return cubin
13 changes: 9 additions & 4 deletions src/arithmetic/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
}
// functions
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
stmt = IRMutator::Mutate(stmt);
return stmt;
}
Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
Expand All @@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
ret_entry_.has_side_effect = stack_.back().has_side_effect;
ret_entry_.max_level = stack_.back().max_level;
stack_.pop_back();
CHECK(expr.defined());
return expr;
}
// call produce to get a cache entry.
Expand Down Expand Up @@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
// subroutine to do produce
Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
CHECK_NE(stack_.size(), 0U);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
Expand All @@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
return ret_entry_.value;
}
// convert sum to expr
Expand Down Expand Up @@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
}
}
}
return vsum;
if (vsum.defined()) {
return vsum;
} else {
return make_zero(t);
}
}
};

Expand Down
14 changes: 13 additions & 1 deletion src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
os << CodeGenCUDA().Compile(f, output_ssa);
os << '\n';
}
std::string ptx = runtime::NVRTCCompile(os.str());
std::string code = os.str();

if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
ptx = f(code).operator std::string();
} else {
ptx = runtime::NVRTCCompile(os.str());
}
std::unordered_map<LoweredFunc, PackedFunc> ret;

runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/packed_func_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
return *(it->second);
}

bool PackedFunc::GlobalExist(const std::string& name) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
return it != r->fmap.end();
}

std::vector<std::string> PackedFunc::ListGlobalNames() {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
std::vector<std::string> keys;
Expand Down
22 changes: 16 additions & 6 deletions tests/python/integration/test_gemm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import tvm
from tvm.addon import nvcc_compiler
import numpy as np

@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc_compiler.compile_source(code, target="ptx")
print(ptx.decode("utf-8"))
return ptx

@tvm.register_func
def tvm_callback_cuda_postproc(code):
print(code)
return code

def test_gemm():
# graph
nn = 1024
Expand All @@ -23,7 +35,6 @@ def test_gemm():
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
s[AA].set_scope("shared")
#s[CC].set_scope("global")
s[BB].set_scope("shared")

scale = 8
Expand Down Expand Up @@ -60,8 +71,6 @@ def check_device(target):
codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes,
max_auto_unroll_step=max_auto_unroll_step)
for c in codes[1:]:
print(c)
if target == "cuda":
ctx = tvm.gpu(0)
else:
Expand All @@ -77,13 +86,14 @@ def check_device(target):
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
for i in range(4):
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

tvm.init_opencl()
check_device("cuda")
check_device("opencl")
#tvm.init_opencl()
#check_device("opencl")

if __name__ == "__main__":
test_gemm()
10 changes: 9 additions & 1 deletion tests/python/unittest/test_runtime_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,17 @@ def myfunc(*args):
assert isinstance(f, tvm.nd.Function)
f(*targs)

def test_byte_array():
s = "hello"
a = bytearray(s, encoding="ascii")

def myfunc(ss):
assert ss == a
f = tvm.convert(myfunc)
f(a)

if __name__ == "__main__":
test_function()
test_convert()
test_get_global()
test_return_func()
test_byte_array()

0 comments on commit 08505e3

Please sign in to comment.