This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Merge remote-tracking branch 'origin/master' into leezu-patch-1
leezu committed Dec 2, 2020

Verified: This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents eed98cb + af26e92 commit d5aa724
Showing 15 changed files with 111 additions and 97 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 9bd2c7 to efdac9
8 changes: 6 additions & 2 deletions CMakeLists.txt
@@ -510,6 +510,11 @@ endif()
FILE(GLOB_RECURSE SOURCE "src/*.cc" "src/*.h" "include/*.h")
FILE(GLOB_RECURSE CUDA "src/*.cu" "src/*.cuh")

if(MSVC)
FILE(GLOB_RECURSE TVM_BRIDGE_SOURCE "src/*/tvm_bridge.cc")
list(REMOVE_ITEM SOURCE ${TVM_BRIDGE_SOURCE})
endif()

if(NOT USE_INTGEMM)
FILE(GLOB_RECURSE INTGEMM_OPERATOR_SOURCE "src/operator/contrib/intgemm/*.cc" "src/operator/contrib/intgemm/*.h")
list(REMOVE_ITEM SOURCE ${INTGEMM_OPERATOR_SOURCE})
@@ -865,12 +870,11 @@ function(BuildTVMOP)
include(cmake/BuildTVM.cmake)
add_subdirectory("3rdparty/tvm")
set_target_properties(tvm PROPERTIES CXX_CLANG_TIDY "") # don't lint 3rdparty dependency
set_target_properties(tvm_topi PROPERTIES CXX_CLANG_TIDY "") # don't lint 3rdparty dependency
set_target_properties(tvm_runtime PROPERTIES CXX_CLANG_TIDY "") # don't lint 3rdparty dependency
endfunction()

if(USE_TVM_OP)
list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
list(APPEND mxnet_LINKER_LIBS tvm_runtime)
BuildTVMOP()
find_package(Python3 REQUIRED)
set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}" "--config" "${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf" "-L" "${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm")
2 changes: 1 addition & 1 deletion ci/docker/runtime_functions.sh
@@ -1054,7 +1054,7 @@ nightly_test_large_tensor() {
set -ex
export PYTHONPATH=./python/
export DMLC_LOG_STACK_TRACE_DEPTH=10
pytest --timeout=0 tests/nightly/test_np_large_array.py
pytest --timeout=0 --forked tests/nightly/test_np_large_array.py
}

#Tests Model backwards compatibility on MXNet
6 changes: 3 additions & 3 deletions cmake/BuildTVM.cmake
@@ -85,9 +85,9 @@ set(USE_LLVM ON)
set(USE_BLAS none)

# /path/to/mkl: mkl root path when use mkl blas library
# set(USE_MKL_PATH /opt/intel/mkl) for UNIX
# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32
set(USE_MKL_PATH none)
# set(USE_MKL /opt/intel/mkl) for UNIX
# set(USE_MKL ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32
set(USE_MKL OFF)

# Whether use contrib.random in runtime
set(USE_RANDOM OFF)
76 changes: 38 additions & 38 deletions contrib/tvmop/basic/ufunc.py
@@ -21,11 +21,11 @@
from .. import assign_by_req, reduce_axes

def compute_add(dtype, ndim):
A = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='A', dtype=dtype)
B = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='B', dtype=dtype)
C = tvm.compute([tvm.size_var() for _ in range(ndim)],
A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
B = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='B', dtype=dtype)
C = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
lambda *index: A[index] + B[index], name='C')
s = tvm.create_schedule(C.op)
s = tvm.te.create_schedule(C.op)
return s, A, B, C


@@ -44,12 +44,12 @@ def vadd(dtype, ndim):
dtype=["float32", "float64"], ndim=[5])
def vadd_gpu(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
s = tvm.create_schedule(C.op)
s = tvm.te.create_schedule(C.op)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
bx, tx = s[C].split(fused, factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B, C]


@@ -62,12 +62,12 @@ def compute_backward_vadd(dtype, ndim, reduce1st, req):
# The compressed bit string is stored in `axes`. And `reduce1st` represents the first bit
# of the compressed bit string. Credit to @junrushao1994 and @yzhliu.
axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
X = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='X', dtype=dtype)
reducer = tvm.comm_reducer(lambda x, y: x + y,
lambda t: tvm.const(0, dtype=t), name="sum")
X = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='X', dtype=dtype)
reducer = tvm.te.comm_reducer(lambda x, y: x + y,
lambda t: tvm.tir.const(0, dtype=t), name="sum")
ret = reduce_axes(X, axes, reducer)
in_grad_a, in_grad = assign_by_req(ret, req)
s = tvm.create_schedule(in_grad.op)
s = tvm.te.create_schedule(in_grad.op)
return s, X, in_grad_a, in_grad, [ret, in_grad]


@@ -90,8 +90,8 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
@@ -101,15 +101,15 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):


def compute_degandrad(dtype, ndim, n):
A = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='A', dtype=dtype)
A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
import math
if n == 0:
B = tvm.compute([tvm.size_var() for _ in range(ndim)],
lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
lambda *index: A[index] * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype), name='B')
else:
B = tvm.compute([tvm.size_var() for _ in range(ndim)],
lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B')
s = tvm.create_schedule(B.op)
B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
lambda *index: A[index] / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype), name='B')
s = tvm.te.create_schedule(B.op)
return s, A, B


@@ -137,43 +137,43 @@ def rad2deg(dtype, ndim):
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 0)
s = tvm.create_schedule(B.op)
s = tvm.te.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B]


@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
s = tvm.create_schedule(B.op)
s = tvm.te.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B]


def compute_backward_degandrad(dtype, ndim, req, n):
ishape = [tvm.size_var() for _ in range(ndim)]
in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype)
out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype)
ishape = [tvm.te.size_var() for _ in range(ndim)]
in_grad_tmp = tvm.te.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
in_grad = tvm.te.placeholder(ishape, name='in_grad', dtype=dtype)
out_grad = tvm.te.placeholder(ishape, name='out_grad', dtype=dtype)
import math
if n == 0:
ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
ret = tvm.te.compute(ishape, lambda *index: out_grad[index] * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype))
else:
ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype))
ret = tvm.te.compute(ishape, lambda *index: out_grad[index] / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype))
if (req == "kAddTo"):
in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
in_grad = tvm.te.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
else:
in_grad = tvm.compute(ishape, lambda *index: ret[index])
s = tvm.create_schedule(in_grad.op)
in_grad = tvm.te.compute(ishape, lambda *index: ret[index])
s = tvm.te.create_schedule(in_grad.op)
return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad]


@@ -208,8 +208,8 @@ def cuda_backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
@@ -225,8 +225,8 @@ def cuda_backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
num_thread = 64
for t in c_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
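
Note: the ufunc.py changes above are a mechanical migration for TVM's namespace refactor: tensor-expression builders (placeholder, compute, size_var, create_schedule, thread_axis, comm_reducer) now live under tvm.te, while IR constants and predicates (const, all, any) live under tvm.tir. A minimal self-contained sketch of the new-style pattern, assuming a recent TVM with the te/tir split; the names below are illustrative and not taken from this commit:

import tvm
from tvm import te

def example_vadd_gpu(dtype="float32", ndim=2):
    # symbolic shapes and placeholders now come from tvm.te
    shape = [te.size_var() for _ in range(ndim)]
    A = te.placeholder(shape, name="A", dtype=dtype)
    B = te.placeholder(shape, name="B", dtype=dtype)
    C = te.compute(shape, lambda *index: A[index] + B[index], name="C")
    s = te.create_schedule(C.op)
    # fuse all output axes, split, and bind to CUDA block/thread axes
    fused = s[C].fuse(*C.op.axis)
    bx, tx = s[C].split(fused, factor=64)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    return s, [A, B, C]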
22 changes: 13 additions & 9 deletions contrib/tvmop/compile.py
@@ -125,23 +125,27 @@ def get_cuda_arch(arch):
help="Path which stores the config file")
arguments = parser.parse_args()

func_list_llvm = []
func_list_cuda = []
mod_llvm = tvm.IRModule({})
mod_cuda = tvm.IRModule({})
has_cuda = False

# TODO: attach instruction features to the library, e.g., avx-512, etc.
for operator_def in __OP_DEF__:
for sch, args, name in operator_def.invoke_all():
name = operator_def.get_op_name(name, args)
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
if tvm.runtime.module.enabled(get_target(operator_def.target)):
func_lower = tvm.lower(sch, args,
name=name,
binds=operator_def.get_binds(args))
func_list.append(func_lower)

lowered_funcs = {get_target("cpu"): func_list_llvm}
if len(func_list_cuda) > 0:
lowered_funcs[get_target("cuda")] = func_list_cuda
if operator_def.target == "cpu":
mod = mod_llvm.update(func_lower)
else:
has_cuda = True
mod_cuda.update(func_lower)

lowered_funcs = {get_target("cpu"): mod_llvm}
if has_cuda:
lowered_funcs[get_target("cuda")] = mod_cuda
cuda_arch = get_cuda_arch(arguments.cuda_arch)
if cuda_arch is None:
logging.info('No cuda arch specified. TVM will try to detect it from the build platform.')
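
Note: the compile.py change replaces the per-target lists of lowered functions with tvm.IRModule containers that are merged via IRModule.update and handed to tvm.build as a target-to-module mapping. A hedged, self-contained sketch of that flow under a recent TVM; the ops below are stand-ins, not the contrib.tvmop registry:

import tvm
from tvm import te

def make_vadd(prefix):
    n = te.size_var(prefix + "_n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    return te.create_schedule(C.op), [A, B, C], C

# one IRModule per target, updated in place as ops are lowered
mod_llvm = tvm.IRModule({})
mod_cuda = tvm.IRModule({})

s, args, _ = make_vadd("cpu")
mod_llvm.update(tvm.lower(s, args, name="vadd_cpu"))   # tvm.lower yields an IRModule here

has_cuda = tvm.runtime.module.enabled("cuda")
if has_cuda:
    s, args, C = make_vadd("gpu")
    bx, tx = s[C].split(s[C].fuse(*C.op.axis), factor=64)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))
    mod_cuda.update(tvm.lower(s, args, name="vadd_gpu"))

lowered_funcs = {"llvm": mod_llvm}
if has_cuda:
    lowered_funcs["cuda"] = mod_cuda
lib = tvm.build(lowered_funcs)   # one runtime module built from the per-target modules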
10 changes: 5 additions & 5 deletions contrib/tvmop/core/fromnumeric.py
@@ -23,10 +23,10 @@

def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=itype)
reduce_output = reduce_axes(a, axes, tvm.sum, otype)
a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='a', dtype=itype)
reduce_output = reduce_axes(a, axes, tvm.tir.sum, otype)
output_placeholder, final_output = assign_by_req(reduce_output, req)
s = tvm.create_schedule(final_output.op)
s = tvm.te.create_schedule(final_output.op)
return s, a, output_placeholder, final_output, [reduce_output, final_output]


@@ -53,8 +53,8 @@ def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
itype, otype, ndim, reduce1st_dim, req)
num_threads = 64
for t in tensor_list:
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_threads)
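
Note: _compute_sum above leans on the module's reduce_axes and assign_by_req helpers. For readers without those helpers at hand, here is a rough, independent sketch of the same idea, reducing the axes flagged by the alternating reduce1st_dim pattern and accumulating in the output dtype; this reimplements the idea and is not the helpers' actual code:

import tvm
from tvm import te

def example_sum(itype="float32", otype="float64", ndim=4, reduce1st_dim=0):
    # alternating mask, e.g. [0, 1, 0, 1] for reduce1st_dim == 0; 1 marks a reduced axis
    axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
    shape = [te.size_var() for _ in range(ndim)]
    a = te.placeholder(shape, name="a", dtype=itype)
    rvars = [te.reduce_axis((0, shape[i])) for i in range(ndim) if axes[i] == 1]
    out_shape = [shape[i] for i in range(ndim) if axes[i] == 0]

    def fcompute(*kept):
        kept, red = list(kept), list(rvars)
        # interleave kept output indices and reduction variables back into full order
        full = [red.pop(0) if axes[i] == 1 else kept.pop(0) for i in range(ndim)]
        return te.sum(a(*full).astype(otype), axis=rvars)

    out = te.compute(out_shape, fcompute, name="out")
    s = te.create_schedule(out.op)
    return s, a, out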
18 changes: 9 additions & 9 deletions contrib/tvmop/core/multiarray.py
@@ -25,9 +25,9 @@ def compute_dot(A, B):
M = A.shape[0]
K = A.shape[1]
N = B.shape[1]
k = tvm.reduce_axis((0, K), 'k')
C = tvm.compute((M, N),
lambda x, y: tvm.sum(A[x, k] * B[k, y], axis=k),
k = tvm.te.reduce_axis((0, K), 'k')
C = tvm.te.compute((M, N),
lambda x, y: tvm.tir.sum(A[x, k] * B[k, y], axis=k),
name='C')
return C

@@ -37,13 +37,13 @@ def dot(dtype, fallback):
cfg = autotvm.get_config()
cfg.define_knob("bn", [64] if fallback else [64, 32])
cfg.define_knob("factor", [4] if fallback else [4])
M = tvm.size_var("M")
K = tvm.size_var("K")
N = tvm.size_var("N")
A = tvm.placeholder((M, K), name='A', dtype=dtype)
B = tvm.placeholder((K, N), name='B', dtype=dtype)
M = tvm.te.size_var("M")
K = tvm.te.size_var("K")
N = tvm.te.size_var("N")
A = tvm.te.placeholder((M, K), name='A', dtype=dtype)
B = tvm.te.placeholder((K, N), name='B', dtype=dtype)
C = compute_dot(A, B)
s = tvm.create_schedule(C.op)
s = tvm.te.create_schedule(C.op)
# Blocking by loop tiling
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], cfg["bn"].val, cfg["bn"].val)
k, = s[C].op.reduce_axis
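
Note: compute_dot expresses the matmul with a te.reduce_axis and a sum reducer, and dot() adds a tiled schedule driven by autotvm knobs. A compact sketch of the same pattern without the autotvm plumbing (tile factor 64 assumed, matching the bn knob's default):

import tvm
from tvm import te

def example_dot(dtype="float32", bn=64):
    M, K, N = te.size_var("M"), te.size_var("K"), te.size_var("N")
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)
    k = te.reduce_axis((0, K), "k")
    C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
    s = te.create_schedule(C.op)
    # blocking by loop tiling, as in the schedule above
    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
    return s, [A, B, C]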
32 changes: 16 additions & 16 deletions contrib/tvmop/core/umath.py
@@ -25,18 +25,18 @@
'less': lambda a, b, *idx: a[idx] < b[idx],
'greater_equal': lambda a, b, *idx: a[idx] >= b[idx],
'less_equal': lambda a, b, *idx: a[idx] <= b[idx],
'logical_and': lambda a, b, *idx: tvm.all(a[idx] != 0, b[idx] != 0),
'logical_or': lambda a, b, *idx: tvm.any(a[idx] != 0, b[idx] != 0),
'logical_xor': lambda a, b, *idx: tvm.all(tvm.any(a[idx] != 0, b[idx] != 0), tvm.any(a[idx] == 0, b[idx] == 0)),
'logical_and': lambda a, b, *idx: tvm.tir.all(a[idx] != 0, b[idx] != 0),
'logical_or': lambda a, b, *idx: tvm.tir.any(a[idx] != 0, b[idx] != 0),
'logical_xor': lambda a, b, *idx: tvm.tir.all(tvm.tir.any(a[idx] != 0, b[idx] != 0), tvm.tir.any(a[idx] == 0, b[idx] == 0)),
}


def _compute_binary_logic(op, dtype, ndim):
a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='a')
b = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='b')
c = tvm.compute([tvm.size_var() for _ in range(ndim)],
a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='a')
b = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='b')
c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c')
s = tvm.create_schedule(c.op)
s = tvm.te.create_schedule(c.op)
return s, a, b, c


@@ -70,8 +70,8 @@ def _binary_logic_gpu(compute_func, op, itype, ndim):
axes = [axis for axis in c.op.axis]
fused = s[c].fuse(*axes)
bx, tx = s[c].split(fused, factor=64)
s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
s[c].bind(tx, tvm.thread_axis('threadIdx.x'))
s[c].bind(bx, tvm.te.thread_axis('blockIdx.x'))
s[c].bind(tx, tvm.te.thread_axis('threadIdx.x'))
return s, [a, b, c]


@@ -90,18 +90,18 @@ def _binary_logic_gpu(compute_func, op, itype, ndim):
'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b,
'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b,
'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b,
'logical_and_scalar': lambda a, b, *idx: tvm.all(a[idx].astype(b.dtype) != 0 , b != 0),
'logical_or_scalar': lambda a, b, *idx: tvm.any(a[idx].astype(b.dtype) != 0, b != 0),
'logical_xor_scalar': lambda a, b, *idx: tvm.all(tvm.any(a[idx].astype(b.dtype) != 0, b != 0), tvm.any(a[idx].astype(b.dtype) == 0, b == 0)),
'logical_and_scalar': lambda a, b, *idx: tvm.tir.all(a[idx].astype(b.dtype) != 0 , b != 0),
'logical_or_scalar': lambda a, b, *idx: tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0),
'logical_xor_scalar': lambda a, b, *idx: tvm.tir.all(tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0), tvm.tir.any(a[idx].astype(b.dtype) == 0, b == 0)),
}


def _compute_binary_scalar_logic(op, dtype, ndim):
a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=dtype)
b = tvm.var('b', dtype='float64')
c = tvm.compute([tvm.size_var() for _ in range(ndim)],
a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='a', dtype=dtype)
b = tvm.te.var('b', dtype='float64')
c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
s = tvm.create_schedule(c.op)
s = tvm.te.create_schedule(c.op)
return s, a, b, c
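
Note: the logical operators now build their predicates with tvm.tir.all / tvm.tir.any, and the scalar variants compare a te.var scalar against the tensor elements cast to its dtype. A small sketch of the scalar logical_and pattern (illustrative only, not the module's registration code):

import tvm
from tvm import te, tir

def example_logical_and_scalar(dtype="float32", ndim=2):
    shape = [te.size_var() for _ in range(ndim)]
    a = te.placeholder(shape, name="a", dtype=dtype)
    b = te.var("b", dtype="float64")          # scalar operand, bound at call time
    c = te.compute(shape,
                   lambda *idx: tir.all(a[idx].astype(b.dtype) != 0, b != 0),
                   name="c")
    s = te.create_schedule(c.op)
    return s, [a, b, c]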


2 changes: 1 addition & 1 deletion contrib/tvmop/opdef.py
@@ -116,7 +116,7 @@ def get_config_spaces(self):

def get_binds(self, args):
if self.auto_broadcast:
return {arg: tvm.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
return {arg: tvm.tir.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
for arg in args}
return None
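
Note: get_binds above requests auto-broadcast buffers so an input dimension of extent 1 is broadcast at runtime instead of requiring a separate kernel per shape combination. A minimal illustration of passing such binds to tvm.lower (hypothetical tensors, not this module's code):

import tvm
from tvm import te, tir

n, m = te.size_var("n"), te.size_var("m")
A = te.placeholder((n, m), name="A")
B = te.placeholder((n, m), name="B")
C = te.compute((n, m), lambda i, j: A[i, j] + B[i, j], name="C")
s = te.create_schedule(C.op)

# buffer_type="auto_broadcast" lets a dimension of extent 1 broadcast at runtime
binds = {t: tir.decl_buffer(t.shape, t.dtype, buffer_type="auto_broadcast")
         for t in (A, B, C)}
mod = tvm.lower(s, [A, B, C], name="bcast_add", binds=binds)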
