Skip to content

Commit

Permalink
Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dpankratz authored Mar 12, 2020
1 parent ec86d7f commit 173b4fc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
}
if (pa) {
Expand Down Expand Up @@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
}
if (pa) {
Expand Down
15 changes: 15 additions & 0 deletions src/tir/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,34 +606,49 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
}

PrimExpr floor(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
return tir::CallNode::make(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic);
}

PrimExpr ceil(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
return tir::CallNode::make(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic);
}

PrimExpr round(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return tir::CallNode::make(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic);
}

PrimExpr nearbyint(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return tir::CallNode::make(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic);
}

PrimExpr trunc(PrimExpr x) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return x;
}
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
Expand Down
19 changes: 17 additions & 2 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,14 @@ def test_bitwise():
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"


def test_float_bitwise():
t = tvm.tir.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs,
lambda lhs, rhs : lhs & rhs
]:
lambda lhs, rhs : lhs & rhs]:
try:
test(t,10.0)
assert False
Expand All @@ -206,6 +206,20 @@ def test_float_bitwise():
except RuntimeError:
pass


def test_divide_by_zero():
for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
try:
test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))
assert False
except tvm.TVMError:
pass


def test_isnan():
x = te.var('x', 'float32')
assert str(tvm.tir.isnan(x)) == 'isnan(x)'
Expand Down Expand Up @@ -250,6 +264,7 @@ def test_equality_string_imm():
test_all()
test_bitwise()
test_float_bitwise()
test_divide_by_zero()
test_isnan()
test_equality()
test_equality_string_imm()
11 changes: 11 additions & 0 deletions tests/python/unittest/test_tvm_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def test_nearbyint():
tvm.testing.assert_allclose(
a_rounded.asnumpy(), np.rint(a.asnumpy()))

def test_round_intrinsics_on_int():
i = tvm.te.var("i", 'int32')
for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
tvm.tir.floor, tvm.tir.nearbyint]:
assert op(tvm.tir.const(10,'int32')).value == 10
assert op(tvm.tir.const(True,'bool')).value == True
assert op(i).same_as(i)

assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False


def test_unary_intrin():
test_funcs = [
Expand Down Expand Up @@ -75,3 +85,4 @@ def run_test(tvm_intrin, np_func):
if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()

0 comments on commit 173b4fc

Please sign in to comment.