From 0c341e4c49a4c95a0c4d68090a7d27759f5adf54 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Sep 2019 11:20:22 -0700 Subject: [PATCH] [ARITH] cleanup the indexmod/div on python side --- python/tvm/autotvm/task/task.py | 4 ++- python/tvm/expr.py | 20 +++++------ src/pass/rewrite_unsafe_select.cc | 6 ++-- tests/python/relay/test_op_level3.py | 14 ++++---- tests/python/relay/test_op_level5.py | 3 +- .../unittest/test_autotvm_flop_calculator.py | 6 ++-- tests/python/unittest/test_ir_builder.py | 5 +-- tests/python/unittest/test_lang_buffer.py | 28 +++++++-------- .../test_pass_rewrite_unsafe_select.py | 2 +- .../unittest/test_schedule_tensorize.py | 9 ++--- topi/python/topi/arm_cpu/conv2d.py | 35 ++++++++++++------- topi/python/topi/cuda/nms.py | 15 +++++--- topi/python/topi/cuda/sort.py | 21 ++++++----- topi/python/topi/nn/bitserial_conv2d.py | 18 +++++----- 14 files changed, 109 insertions(+), 77 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 901183f46948f..e0db27574898a 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -350,7 +350,9 @@ def _count_flop(exp): return _count_flop(exp.value) if isinstance(exp, expr.Var): return 0 - if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod, + if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, + expr.Div, expr.Mod, + expr.FloorDiv, expr.FloorMod, expr.Max, expr.Min, expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE, expr.And, expr.Or, expr.Not)): diff --git a/python/tvm/expr.py b/python/tvm/expr.py index a8bd651d64694..5b7c60d819bd2 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -72,23 +72,23 @@ def __rmul__(self, other): return _generic.multiply(other, self) def __div__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(self, other) def __rdiv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(other, self) def __truediv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(self, other) def __rtruediv__(self, other): - # if _dtype_is_int(self) and _dtype_is_int(other): - # raise div_ambiguity_error() + if _dtype_is_int(self) and _dtype_is_int(other): + raise div_ambiguity_error() return _generic.divide(other, self) def __floordiv__(self, other): @@ -100,8 +100,8 @@ def __rfloordiv__(self, other): return _generic.divide(other, self) def __mod__(self, other): - # raise div_ambiguity_error() - return _make._OpMod(self, other) + raise div_ambiguity_error() + # return _make._OpMod(self, other) def __neg__(self): neg_one = _api_internal._const(-1, self.dtype) diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 871efcae615d8..62db0b414be12 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } bool VisitExpr_(const Div* op) final { return BinaryOp(op); } bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); } bool VisitExpr_(const Min* op) final { return BinaryOp(op); } bool VisitExpr_(const Max* op) final { return BinaryOp(op); } bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 424462fbe0c42..2d92489328af8 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): yy = run_infer_type(y.astuple()) assert yy.checked_type == ret_type + idxd = tvm.indexdiv + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") axis = tvm.var("axis") verify_split((5, 5, 2, 2), 5, @@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None): axis=0) verify_split((d1, d2, d3, d4), 4, relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), - relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"), + relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])), axis=2) verify_split((d1, d2, d3, d4), 2, relay.ty.TupleType(tvm.convert([ - relay.ty.TensorType((d1/2, d2, d3, d4), "float32"), - relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])), + relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"), + relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])), axis=0) verify_split((d1, d2, d3, d4), (2, 4, 7), relay.ty.TupleType(tvm.convert([ diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index f4ac673cf3785..8c107351c81a4 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape): assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + idxd = tvm.indexdiv verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) - verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2)) + verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2))) def test_yolo_reorg(): def verify_yolo_reorg(shape, stride): diff --git a/tests/python/unittest/test_autotvm_flop_calculator.py b/tests/python/unittest/test_autotvm_flop_calculator.py index 54ade9a052678..5cafd02c45bf0 100644 --- a/tests/python/unittest/test_autotvm_flop_calculator.py +++ b/tests/python/unittest/test_autotvm_flop_calculator.py @@ -60,14 +60,14 @@ def test_pack_gemm(): k = tvm.reduce_axis((0, L)) bn = 4 - fld = tvm.floordiv - flm = tvm.floormod + idxd = tvm.indexdiv + idxm = tvm.indexmod A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j]) B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j]) C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj: tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k])) - C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)]) + C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)]) s = tvm.create_schedule([C.op]) assert compute_flop(s) == 2 * N * L * M diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index ef58174d4474f..c910c62424f0c 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -109,14 +109,15 @@ def test_gpu(): dtype = "float32" A = tvm.placeholder((n,), name='A') B = tvm.placeholder((n,), name='B') - fld = tvm.floordiv + idxd = tvm.indexdiv + def test_device_ir(A, B, C): n = A.shape[0] max_threads = 32 ib = tvm.ir_builder.create() bx = tvm.thread_axis("blockIdx.x") tx = tvm.thread_axis("threadIdx.x") - ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads)) + ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads)) ib.scope_attr(tx, "thread_extent", max_threads) idx = bx.var * max_threads + tx.var Aptr = ib.buffer_ptr(A) diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index 9ad8b62821cfe..32c17452269ea 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod(): def assert_simplified_equal(index_simplified, index_direct): assert tvm.ir_pass.Equal(index_simplified, index_direct),\ "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct) - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod # Test Case1 index_simplified = A_stride.vload( - (idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1)) + (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1)) index_direct = A_stride.vload((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case2 - index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1))) - index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s)))) + index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n), + idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1))) + index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 - index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + - idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + - idxmod(idxmod(k0, idxdiv(k1, s)), n))) + index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + + idxd(idxm(k0, idxd(k1, s)), n), + idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) + + idxm(idxm(k0, idxd(k1, s)), n))) index_direct = A.vload((0, k0)) assert_simplified_equal(index_simplified, index_direct) # Test Case4 (not able to simplify) - index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), - idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))) - index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n + - (idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))) + index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n), + idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))) + index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n + + (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))) assert_simplified_equal(index_simplified, index_direct) diff --git a/tests/python/unittest/test_pass_rewrite_unsafe_select.py b/tests/python/unittest/test_pass_rewrite_unsafe_select.py index b2d73ec00ce80..4c42899be62a0 100644 --- a/tests/python/unittest/test_pass_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_pass_rewrite_unsafe_select.py @@ -28,7 +28,7 @@ def test_rewrite_Select(): tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value - a = tvm.expr.Select(i>10, y, z) + a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z) aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value assert yy.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else" diff --git a/tests/python/unittest/test_schedule_tensorize.py b/tests/python/unittest/test_schedule_tensorize.py index 4bad959c2453d..59adf0cc7e994 100644 --- a/tests/python/unittest/test_schedule_tensorize.py +++ b/tests/python/unittest/test_schedule_tensorize.py @@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): # This tests whether algorithm and intrinsics expressions are simplified # as much as possible first and then checked for equality. See Issue #696 def test_tensorize_op(): - tdiv = tvm.truncdiv - tmod = tvm.truncmod + idxd = tvm.indexdiv + idxm = tvm.indexmod + def op_intrin(): bh = 9 bw = 9 x = tvm.placeholder((5, 5), name='A') y = tvm.compute((bh, bw), - lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)]) + lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)]) def intrin_func(ins, outs): xx, = ins @@ -239,7 +240,7 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(y.op, intrin_func) A = tvm.placeholder((5, 5), name='A') - B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)]) + B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)]) bt = op_intrin() s = tvm.create_schedule(B.op) diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 73a97d2bb33c9..f5cbbf0f7badc 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") + idxd = tvm.indexdiv + idxm = tvm.indexmod + r = KW m = tile_size alpha = m + r - 1 @@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt VK = cfg['tile_k'].size[-1] # pack input tile - input_tile = tvm.compute((C, P // VP, alpha, alpha, VP), + input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP), lambda c, b, eps, nu, bb: - data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps] - [(b*VP+bb) % nW * m + nu], + data_pad[idxd(b*VP + bb, nH*nW), c, + idxm(idxd(b*VP + bb, nW), nH) * m + eps, + idxm(b*VP + bb, nW) * m + nu], name='d') # transform kernel @@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt else: r_kh = tvm.reduce_axis((0, KH), 'r_kh') r_kw = tvm.reduce_axis((0, KW), 'r_kw') - U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk: + U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') # transform image r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_nu = tvm.reduce_axis((0, alpha), 'r_nu') - V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb: + V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb: tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) * B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V') # batch gemm c = tvm.reduce_axis((0, C), name='c') M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b: - tvm.sum(U[eps][nu][k // VK][c][k % VK] * - V[eps][nu][b // VP][c][b % VP], axis=c), name='M') + tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] * + V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M') # inverse transform r_eps = tvm.reduce_axis((0, alpha), 'r_eps') @@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt # unpack output output = tvm.compute((N, K, H, W), lambda n, k, h, w: - Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], + Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m), + idxm(h, m), idxm(w, m)], name='output', tag='winograd_conv2d_output') # we have to manually assign effective GFLOP for winograd @@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): N, CI, H, W = get_const_tuple(data.shape) CO, _, KH, KW = get_const_tuple(kernel.shape) + idxd = tvm.indexdiv + if groups == 1: # query config of this workload workload = autotvm.task.args_to_workload( @@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data - new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype) + new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d) dispatch_ctx.update(target, new_workload, cfg) @@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], tile_size=tile_size) weight = F.reshape(weight, - newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) + newshape=(KH + tile_size - 1, + KW + tile_size - 1, + idxd(CO, VC), VC, CI)) weight = F.transpose(weight, axes=[0, 1, 2, 4, 3]) copy_inputs[1] = weight @@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data - new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC), + new_weight = tvm.placeholder((KH + tile_size - 1, + KH + tile_size -1, + idxd(CO, VC), CI, VC), kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_weight, strides, padding, dilation, @@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): # Store the same config for the altered operator (workload) new_data = data CO, M, KH, KW = get_const_tuple(kernel.shape) - new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype) + new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 6ff8a79d36301..33fc7249802b6 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx new_range = num_anchors // elem_per_thread + 1 + idxd = tvm.indexdiv + idxm = tvm.indexmod # Scan: Downsweep: with ib. if_scope(tid < batch_size * num_anchors): - i = tid // num_anchors # number of batches - j = tid % num_anchors # number of anchors + i = idxd(tid, num_anchors) # number of batches + j = idxm(tid, num_anchors) # number of anchors with ib.if_scope(j < elem_per_thread): idx[tid] = idx_in[tid] with ib.else_scope(): - idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1] + idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1] return ib.get() @@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx + idxd = tvm.indexdiv + idxm = tvm.indexmod + with ib.if_scope(tid < batch_size * num_anchors): - i = tid // num_anchors - j = tid % num_anchors + i = idxd(tid, num_anchors) + j = idxm(tid, num_anchors) base_idx = i * num_anchors * elem_length with ib.if_scope(flag[tid] > 0): with ib.for_range(0, elem_length) as k: diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index c45465e31624c..b02c14b47e60d 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): ib.emit(tvm.make.Call(None, 'tvm_storage_sync', tvm.convert(['shared']), tvm.expr.Call.Intrinsic, None, 0)) + idxd = tvm.indexdiv + idxm = tvm.indexmod with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_after) as j: @@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): base_idx = i * shape[axis] * axis_mul_after + j # OddEvenTransposeSort with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < (current_sort_num + 1) // 2): - offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): + offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after if is_ascend: - cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num, values_out[offset] > values_out[offset + axis_mul_after]) else: - cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, + cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num, values_out[offset] < values_out[offset + axis_mul_after]) with ib.if_scope(cond): temp_data[0] = values_out[offset] @@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + idxd = tvm.indexdiv + idxm = tvm.indexmod + with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_after) as j: current_sort_num = valid_count[i * axis_mul_after + j] @@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[base_idx + tid * axis_mul_after] = tid # OddEvenTransposeSort with ib.for_range(0, current_sort_num) as k: - with ib.if_scope(tid < (current_sort_num + 1) // 2): - offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tid < idxd(current_sort_num + 1, 2)): + offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after with ib.if_scope(tvm.all(is_ascend == 1, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ + 2 * tid + idxm(k, 2) + 1 < current_sort_num, \ data[offset] > data[offset + axis_mul_after])): temp_data[0] = data[offset] data[offset] = data[offset + axis_mul_after] @@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[offset] = output[offset + axis_mul_after] output[offset + axis_mul_after] = temp_index[0] with ib.if_scope(tvm.all(is_ascend == 0, \ - 2 * tid + (k % 2) + 1 < current_sort_num, \ + 2 * tid + idxm(k, 2) + 1 < current_sort_num, \ data[offset] < data[offset + axis_mul_after])): temp_data[0] = data[offset] data[offset] = data[offset + axis_mul_after] diff --git a/topi/python/topi/nn/bitserial_conv2d.py b/topi/python/topi/nn/bitserial_conv2d.py index 2faabf2bbf897..932c141450acb 100644 --- a/topi/python/topi/nn/bitserial_conv2d.py +++ b/topi/python/topi/nn/bitserial_conv2d.py @@ -313,13 +313,14 @@ def _conv(n, co, h, w, vh, vw, vc): axis=[ci, dh, dw, b1, b2]) conv = tvm.compute(ovshape, _conv, name='conv_out') - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod return tvm.compute( oshape, lambda n, co, h, w: - conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv( - w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], + conv[n, + idxd(co, VC), idxd(h, VH), idxd(w, VW), + idxm(h, VH), idxm(w, VW), idxm(co, VC)], name='conv_vec', tag='spatial_bitserial_conv_nchw') @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') @@ -419,12 +420,13 @@ def _conv(n, h, w, co, vh, vw, vc): conv = tvm.compute(ovshape, _conv, name='conv') - idxdiv = tvm.indexdiv - idxmod = tvm.indexmod + idxd = tvm.indexdiv + idxm = tvm.indexmod return tvm.compute( oshape, lambda n, h, w, co: - conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv( - co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], + conv[n, + idxd(h, VH), idxd(w, VW), idxd(co, VC), + idxm(h, VH), idxm(w, VW), idxm(co, VC)], name='output_unpack', tag='spatial_bitserial_conv_nhwc') @tvm.target.generic_func