Skip to content

Commit

Permalink
[ARITH] cleanup the indexmod/div on python side (#4028)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 28, 2019
1 parent bbf82e0 commit f98035b
Show file tree
Hide file tree
Showing 24 changed files with 188 additions and 116 deletions.
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,9 @@ def _count_flop(exp):
return _count_flop(exp.value)
if isinstance(exp, expr.Var):
return 0
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
expr.Div, expr.Mod,
expr.FloorDiv, expr.FloorMod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,23 @@ def __rmul__(self, other):
return _generic.multiply(other, self)

def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __truediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)

def __rtruediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)

def __floordiv__(self, other):
Expand All @@ -100,8 +100,8 @@ def __rfloordiv__(self, other):
return _generic.divide(other, self)

def __mod__(self, other):
# raise div_ambiguity_error()
return _make._OpMod(self, other)
raise div_ambiguity_error()
# return _make._OpMod(self, other)

def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
Expand Down
6 changes: 4 additions & 2 deletions src/pass/rewrite_unsafe_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
Expand Down
14 changes: 8 additions & 6 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
yy = run_infer_type(y.astuple())
assert yy.checked_type == ret_type

idxd = tvm.indexdiv

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis")
verify_split((5, 5, 2, 2), 5,
Expand All @@ -393,15 +395,15 @@ def verify_split(dshape, indices_or_sections, ret_type, axis=None):
axis=0)
verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
axis=2)
verify_split((d1, d2, d3, d4), 2,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
axis=0)
verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,9 @@ def verify_yolo_reorg(shape, stride, out_shape):
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")

n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
idxd = tvm.indexdiv
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))

def test_yolo_reorg():
def verify_yolo_reorg(shape, stride):
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_autotvm_flop_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L))

bn = 4
fld = tvm.floordiv
flm = tvm.floormod
idxd = tvm.indexdiv
idxm = tvm.indexmod

A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])

s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/test_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def check_cuda(dtype, n, lanes):
print("skip because gpu does not support int8")
return
A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
Expand Down Expand Up @@ -165,9 +165,10 @@ def test_cuda_shuffle():
print("skip because cuda is not enabled..")
return

idxm = tvm.indexmod
a = tvm.placeholder((64, ), 'int32')
b = tvm.placeholder((64, ), 'int32')
c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
sch = tvm.create_schedule(c.op)
x = c.op.axis[0]
xo, xi = sch[c].split(x, 4)
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/test_ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,15 @@ def test_gpu():
dtype = "float32"
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv
idxd = tvm.indexdiv

def test_device_ir(A, B, C):
n = A.shape[0]
max_threads = 32
ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A)
Expand Down
28 changes: 14 additions & 14 deletions tests/python/unittest/test_lang_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Test Case1
index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)

# Test Case2
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
assert_simplified_equal(index_simplified, index_direct)
# Test Case3
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxd(idxm(k0, idxd(k1, s)), n),
idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxm(idxm(k0, idxd(k1, s)), n)))
index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify)
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
(idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
assert_simplified_equal(index_simplified, index_direct)


Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_pass_rewrite_unsafe_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_rewrite_Select():
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value

a = tvm.expr.Select(i>10, y, z)
a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
Expand Down
9 changes: 5 additions & 4 deletions tests/python/unittest/test_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,15 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor):
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
idxd = tvm.indexdiv
idxm = tvm.indexmod

def op_intrin():
bh = 9
bw = 9
x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])

def intrin_func(ins, outs):
xx, = ins
Expand All @@ -239,7 +240,7 @@ def intrin_func(ins, outs):
return tvm.decl_tensor_intrin(y.op, intrin_func)

A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
bt = op_intrin()
s = tvm.create_schedule(B.op)

Expand Down
14 changes: 11 additions & 3 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO)

idxd = tvm.indexdiv
idxm = tvm.indexmod

# Pad input channels of weights and data when it is not a multiple of 8
if CI_packed % 8 != 0:
CI_PAD = CI_packed % 8
Expand Down Expand Up @@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')

kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
idxm = tvm.indexmod
if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0:
kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])

N, H, W, IB, CI = data_q.shape
Expand Down Expand Up @@ -147,8 +151,12 @@ def _unipolar_conv(n, h, w, co, vh, vw, vc):
else:
conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')

conv = tvm.compute(oshape, lambda n, h, w, co:
conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),

conv = tvm.compute(oshape,
lambda n, h, w, co:
conv_vec[n,
idxd(h, VH), idxd(w, VW), idxd(co, VC),
idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
name='conv', tag='spatial_bitserial_conv_nhwc')

return conv
Expand Down
Loading

0 comments on commit f98035b

Please sign in to comment.