fp16 conversion
zhiics committed May 13, 2019
1 parent 0782e4c commit 8e10079
Showing 2 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
@@ -155,7 +155,7 @@ def build(self, func, target=None, target_host=None, params=None):
             Host compilation target, if target is device.
             When TVM compiles device specific program such as CUDA,
             we also need host(CPU) side code to interact with the driver
-            setup the dimensions and parameters correctly.
+            to setup the dimensions and parameters correctly.
             target_host is used to specify the host side codegen target.
             By default, llvm is used if it is enabled,
             otherwise a stackvm intepreter is used.
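
For context, the behavior this docstring describes can be exercised with a small compilation sketch (illustrative only, not part of this commit; the single-operator graph below is a made-up example and assumes a CUDA-enabled TVM build):

import tvm
from tvm import relay

# A toy graph: add a constant to a float32 vector.
x = relay.var("x", shape=(10,), dtype="float32")
func = relay.Function([x], relay.add(x, relay.const(1.0)))

# Device code is generated for the CUDA target, while target_host picks the
# host(CPU) codegen that emits the driver/launch glue described above.
with relay.build_config(opt_level=1):
    graph_json, lib, params = relay.build(func, target="cuda", target_host="llvm")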
35 changes: 32 additions & 3 deletions tests/python/relay/test_cpp_build_module.py
@@ -20,9 +20,6 @@
 from tvm import relay
 from tvm.contrib.nvcc import have_fp16
 
-from tvm._ffi.function import _init_api
-_init_api("tvm.relay.build_module")
-
 
 def test_basic_build():
     tgt = "llvm"
@@ -99,6 +96,38 @@ def test_fp16_build():
                                    atol=1e-5, rtol=1e-5)
 
 
+def test_fp16_conversion():
+    def check_conversion(tgt, ctx):
+        n = 10
+
+        for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
+            x = relay.var("x", relay.TensorType((n,), src))
+            y = x.astype(dst)
+            func = relay.Function([x], y)
+
+            # init input
+            X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
+            params = {"p0": X}
+
+            # build
+            with relay.build_config(opt_level=1):
+                g_json, mmod, params = relay.build(func, tgt, params=params)
+
+            # test
+            rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+            rt.set_input("x", X)
+            rt.load_params(relay.save_param_dict(params))
+            rt.run()
+            out = rt.get_output(0)
+
+            np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
+                                       atol=1e-5, rtol=1e-5)
+
+    for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
+        check_conversion(target, ctx)
+
+
 if __name__ == "__main__":
     test_basic_build()
     test_fp16_build()
+    test_fp16_conversion()
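
As a side note on the new test: x.astype(dst) there is the expression-level shorthand for an explicit Relay cast, so the same conversion graph could also be written with the cast op directly (a sketch, assuming the usual relay.cast alias; the (10,) shape is illustrative):

from tvm import relay

x = relay.var("x", relay.TensorType((10,), "float32"))
y = relay.cast(x, "float16")   # equivalent to x.astype("float16") in the test
func = relay.Function([x], y)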
