Skip to content

Commit

Permalink
[RUNTIME][SDACCEL] Add support for multiple kernels (apache#1424)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazum authored and tqchen committed Jul 14, 2018
1 parent b404385 commit 12cf343
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 29 deletions.
44 changes: 25 additions & 19 deletions python/tvm/contrib/sdaccel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,21 @@


@register_func("tvm_callback_sdaccel_compile")
def compile_vhls(code, kernel):
def compile_vhls(kernel_info):
"""Compile Vivado HLS code for SDAccel.
Parameters
----------
code : str
The Vivado HLS code.
kernel : str
The kernel to compile or link.
kernel_info : list of (str, str)
List of kernel information. The kernel information is a tuple of
function name and source code.
Return
------
xclbin : bytearray
The bytearray of the xclbin
"""
tmp_dir = util.tempdir()
tmp_cpp = tmp_dir.relpath("input.cpp")
tmp_xo = tmp_dir.relpath("output.xo")
tmp_xclbin = tmp_dir.relpath("output.xclbin")

with open(tmp_cpp, "wb") as out_file:
out_file.write(bytes(code))

sdk = os.environ.get("XILINX_SDX", None)
xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc"
Expand All @@ -41,15 +33,29 @@ def compile_vhls(code, kernel):
if platform is None:
raise RuntimeError("No Xlinx device specified.")

# build xo
args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", kernel] + \
advanced_params + [tmp_cpp]
returncode = subprocess.call(args)
if returncode != 0:
raise RuntimeError("Compile error")
tmp_xo_files = []
for funcname, code in kernel_info:
funcname = funcname.value
code = code.value

tmp_cpp = tmp_dir.relpath(funcname + ".cpp")
tmp_xo = tmp_dir.relpath(funcname + ".xo")

with open(tmp_cpp, "wb") as out_file:
out_file.write(bytes(code))

# build xo
args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", funcname] + \
advanced_params + [tmp_cpp]
returncode = subprocess.call(args)
if returncode != 0:
raise RuntimeError("Compile error")

tmp_xo_files.append(tmp_xo)

# build xclbin
args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin, tmp_xo] + \
tmp_xclbin = tmp_dir.relpath("output.xclbin")
args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin] + tmp_xo_files + \
advanced_params
returncode = subprocess.call(args)
if returncode != 0:
Expand Down
25 changes: 16 additions & 9 deletions src/codegen/codegen_vhls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,33 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs) {
bool output_ssa = false;
CodeGenVivadoHLS cg;

CHECK_EQ(funcs.size(), 1);
const std::string funcname = funcs[0]->name;

// Generate source code for get_source().
cg.Init(output_ssa);

for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string();
std::string whole_code = cg.Finish();

// Generate source code for compilation.
Array<Array<Expr> > kernel_info;
for (LoweredFunc f : funcs) {
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
cg.AddFunction(f);
std::string code = cg.Finish();
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string();
}
kernel_info.push_back(Array<Expr>({f->name, code}));
}

std::string xclbin;
if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
xclbin = (*f)(code, funcname).operator std::string();
xclbin = (*f)(kernel_info).operator std::string();
} else {
LOG(FATAL) << "Cannot compile Vivado HLS code.";
}
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), code);
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code);
}

TVM_REGISTER_API("codegen.build_sdaccel")
Expand Down
41 changes: 40 additions & 1 deletion tests/python/integration/test_ewise_fpga.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_exp():
s[B].bind(px, tvm.thread_axis("pipeline"))

# one line to build the function.
def check_device(device, host="stackvm"):
def check_device(device, host="llvm"):
if not tvm.module.enabled(host):
return
ctx = tvm.context(device, 0)
Expand All @@ -42,5 +42,44 @@ def check_device(device, host="stackvm"):
check_device("sdaccel")


def test_multi_kernel():
    """Integration test for SDAccel codegen with more than one device kernel.

    Builds a two-stage elementwise pipeline, C = A + B followed by
    D = A + C, so each compute stage becomes its own device kernel —
    exercising the multi-kernel compile path this commit adds.
    The final check asserts D == 2*A + B.
    """
    # graph: two chained elementwise computes over 1024-element vectors
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    D = tvm.compute(A.shape, lambda *i: A(*i) + C(*i), name='D')
    s = tvm.create_schedule(D.op)
    # create iter var and assign them tags.
    # Both stages are bound to the "pipeline" thread axis (nparts=1), the
    # binding the FPGA/SDAccel backend expects for each kernel.
    px, x = s[C].split(C.op.axis[0], nparts=1)
    s[C].bind(px, tvm.thread_axis("pipeline"))
    px, x = s[D].split(D.op.axis[0], nparts=1)
    s[D].bind(px, tvm.thread_axis("pipeline"))

    # one line to build the function.
    def check_device(device, host="llvm"):
        # Skip silently when the host backend or the target device is not
        # available in this build/runtime environment.
        if not tvm.module.enabled(host):
            return
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        fadd = tvm.build(s, [A, B, C, D],
                         device, host,
                         name="myadd")
        # NOTE(review): ctx is re-queried here although the value above is
        # still live — the second call looks redundant; confirm.
        ctx = tvm.context(device, 0)
        # launch the kernel.
        n = 1024
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
        c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), ctx)
        d = tvm.nd.array(np.random.uniform(size=n).astype(D.dtype), ctx)
        fadd(a, b, c, d)
        # D = A + (A + B) = 2*A + B
        np.testing.assert_allclose(
            d.asnumpy(), a.asnumpy() * 2 + b.asnumpy(), rtol=1e-5)

    check_device("sdaccel")


if __name__ == "__main__":
test_exp()
test_multi_kernel()

0 comments on commit 12cf343

Please sign in to comment.