[ARITH] cleanup the indexmod/div on python side
tqchen committed Sep 27, 2019
1 parent 368a4ae commit 0c341e4
Showing 14 changed files with 109 additions and 77 deletions.
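
The recurring change across these files is replacing Python's // and % on TVM expressions with the explicit tvm.indexdiv / tvm.indexmod helpers. A minimal sketch of the new idiom (an illustration by the editor, not code from this commit; it assumes a TVM build from around this change where tvm.indexdiv and tvm.indexmod are available):

import tvm

n = tvm.var("n")
bn = 4

# Old style seen in the files below: Python operators on symbolic exprs,
# whose division semantics are ambiguous for integer dtypes.
#   i_outer = n // bn
#   i_inner = n % bn

# New style: the intent (index arithmetic on non-negative values) is explicit.
i_outer = tvm.indexdiv(n, bn)
i_inner = tvm.indexmod(n, bn)
print(i_outer, i_inner)
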
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/task.py
@@ -350,7 +350,9 @@ def _count_flop(exp):
return _count_flop(exp.value)
if isinstance(exp, expr.Var):
return 0
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
expr.Div, expr.Mod,
expr.FloorDiv, expr.FloorMod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
20 changes: 10 additions & 10 deletions python/tvm/expr.py
@@ -72,23 +72,23 @@ def __rmul__(self, other):
return _generic.multiply(other, self)

def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __truediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rtruediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __floordiv__(self, other):
@@ -100,8 +100,8 @@ def __rfloordiv__(self, other):
return _generic.divide(other, self)

def __mod__(self, other):
# raise div_ambiguity_error()
return _make._OpMod(self, other)
raise div_ambiguity_error()
# return _make._OpMod(self, other)

def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
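The python/tvm/expr.py hunk above turns the previously commented-out guards on: / between two integer expressions, and % in general, now raise a division-ambiguity error instead of silently picking a semantics. A hedged sketch of what user code sees (the exact error type and message are assumptions; only the raise-on-ambiguity behavior comes from the diff):

import tvm

x = tvm.var("x", dtype="int32")
y = tvm.var("y", dtype="int32")

try:
    x / y  # now rejected: ambiguous for integer dtypes
except Exception as err:
    print("ambiguous integer division rejected:", err)

# Explicit alternatives, as used throughout the rest of this commit:
q_idx = tvm.indexdiv(x, y)    # index division, non-negative operands assumed
q_floor = tvm.floordiv(x, y)  # explicit floor semantics
q_trunc = tvm.truncdiv(x, y)  # explicit C-style truncation
print(q_idx, q_floor, q_trunc)
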
6 changes: 4 additions & 2 deletions src/pass/rewrite_unsafe_select.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
14 changes: 8 additions & 6 deletions tests/python/relay/test_op_level3.py
@@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
yy = run_infer_type(y.astuple())
assert yy.checked_type == ret_type

idxd = tvm.indexdiv

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis")
verify_split((5, 5, 2, 2), 5,
@@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
axis=0)
verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
axis=2)
verify_split((d1, d2, d3, d4), 2,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
axis=0)
verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level5.py
@@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape):
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")

n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
idxd = tvm.indexdiv
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))

def test_yolo_reorg():
def verify_yolo_reorg(shape, stride):
6 changes: 3 additions & 3 deletions tests/python/unittest/test_autotvm_flop_calculator.py
@@ -60,14 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L))

bn = 4
fld = tvm.floordiv
flm = tvm.floormod
idxd = tvm.indexdiv
idxm = tvm.indexmod

A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])

s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
5 changes: 3 additions & 2 deletions tests/python/unittest/test_ir_builder.py
@@ -109,14 +109,15 @@ def test_gpu():
dtype = "float32"
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv
idxd = tvm.indexdiv

def test_device_ir(A, B, C):
n = A.shape[0]
max_threads = 32
ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A)
28 changes: 14 additions & 14 deletions tests/python/unittest/test_lang_buffer.py
@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Test Case1
index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)

# Test Case2
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
assert_simplified_equal(index_simplified, index_direct)
# Test Case3
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxd(idxm(k0, idxd(k1, s)), n),
idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxm(idxm(k0, idxd(k1, s)), n)))
index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify)
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
(idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
assert_simplified_equal(index_simplified, index_direct)


2 changes: 1 addition & 1 deletion tests/python/unittest/test_pass_rewrite_unsafe_select.py
@@ -28,7 +28,7 @@ def test_rewrite_Select():
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value

a = tvm.expr.Select(i>10, y, z)
a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
9 changes: 5 additions & 4 deletions tests/python/unittest/test_schedule_tensorize.py
@@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor):
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
idxd = tvm.indexdiv
idxm = tvm.indexmod

def op_intrin():
bh = 9
bw = 9
x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])

def intrin_func(ins, outs):
xx, = ins
@@ -239,7 +240,7 @@ def intrin_func(ins, outs):
return tvm.decl_tensor_intrin(y.op, intrin_func)

A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
bt = op_intrin()
s = tvm.create_schedule(B.op)

35 changes: 23 additions & 12 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

idxd = tvm.indexdiv
idxm = tvm.indexmod

r = KW
m = tile_size
alpha = m + r - 1
@@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
VK = cfg['tile_k'].size[-1]

# pack input tile
input_tile = tvm.compute((C, P // VP, alpha, alpha, VP),
input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP),
lambda c, b, eps, nu, bb:
data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
[(b*VP+bb) % nW * m + nu],
data_pad[idxd(b*VP + bb, nH*nW), c,
idxm(idxd(b*VP + bb, nW), nH) * m + eps,
idxm(b*VP + bb, nW) * m + nu],
name='d')

# transform kernel
@@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
else:
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

# transform image
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb:
tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')

# batch gemm
c = tvm.reduce_axis((0, C), name='c')
M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
tvm.sum(U[eps][nu][k // VK][c][k % VK] *
V[eps][nu][b // VP][c][b % VP], axis=c), name='M')
tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] *
V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M')

# inverse transform
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
@@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt

# unpack output
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m),
idxm(h, m), idxm(w, m)],
name='output', tag='winograd_conv2d_output')

# we have to manually assign effective GFLOP for winograd
@@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
N, CI, H, W = get_const_tuple(data.shape)
CO, _, KH, KW = get_const_tuple(kernel.shape)

idxd = tvm.indexdiv

if groups == 1:
# query config of this workload
workload = autotvm.task.args_to_workload(
@@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):

# Store the same config for the altered operator (workload)
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)
@@ -553,15 +560,19 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])

copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size

# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
new_weight = tvm.placeholder((KH + tile_size - 1,
KH + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
@@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload)
new_data = data
CO, M, KH, KW = get_const_tuple(kernel.shape)
new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
depthwise_conv2d_nchw)
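In the Winograd kernels above, a flat packed index such as k over K values arranged in VK-wide tiles is split back into tile and lane with idxd/idxm. A tiny sketch of that flatten/unflatten pattern (illustrative names and values chosen by the editor, not taken from the file):

import tvm

VK = 4
k = tvm.var("k")

tile = tvm.indexdiv(k, VK)  # which VK-wide tile the flat index k falls in
lane = tvm.indexmod(k, VK)  # position of k inside that tile
# For non-negative k, tile * VK + lane reconstructs k.
print(tile, lane)
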
15 changes: 10 additions & 5 deletions topi/python/topi/cuda/nms.py
@@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
new_range = num_anchors // elem_per_thread + 1
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Scan: Downsweep:
with ib. if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors # number of batches
j = tid % num_anchors # number of anchors
i = idxd(tid, num_anchors) # number of batches
j = idxm(tid, num_anchors) # number of anchors
with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid]
with ib.else_scope():
idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1]
idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]

return ib.get()

@@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx

idxd = tvm.indexdiv
idxm = tvm.indexmod

with ib.if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors
j = tid % num_anchors
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0):
with ib.for_range(0, elem_length) as k:
(Diffs for the remaining 2 changed files are not shown here.)
