Skip to content

Commit

Permalink
Fix relay i64 error
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Apr 4, 2020
1 parent 6e1cd82 commit 318fea3
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
for (IndexExpr val : shape) {
const int64_t* pval = tir::as_const_int(val);
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
res.push_back(IntImm(DataType::Int(32), *pval));
// CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
// CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
// res.push_back(IntImm(DataType::Int(32), *pval));
res.push_back(val);
} else if (val->IsInstance<tir::AnyNode>()) {
res.push_back(val.as<tir::AnyNode>()->ToVar());
} else {
Expand Down
115 changes: 115 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import run_opt_pass
import numpy as np


def test_fuse_simple():
Expand Down Expand Up @@ -621,6 +622,117 @@ def expected():
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, after)


def test_fuse_strided_slice():
    """Test fusion case involving concat and strided_slice.

    Builds a graph with int64 shapes (the regression this commit targets),
    runs the fusion passes, and checks both that the fused module still
    compiles to LLVM and that its structure matches the expected fused form.
    """

    def before():
        # int64 constants in the shape exercise the i64 path in GetShape.
        shape = (tvm.tir.const(10, "int64"),
                 tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape = (tvm.tir.const(10, "int64"),
                 tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        p0 = relay.var("p0", shape=shape)
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])

        # concat + strided_slice are expected to fuse into one primitive function.
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        y = relay.Call(f0, [x])
        return relay.Function([x], y)

    orig = before()
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    # Building verifies that codegen handles the i64 shapes after fusion.
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_take():
    """Test fusion case involving concat and take"""

    def before():
        data_shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        inp = relay.var("x", shape=data_shape)
        joined = relay.concatenate([inp, inp], axis=-1)
        result = relay.op.take(joined, indices=relay.const([0], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(result), result)

    def expected():
        data_shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        idx_shape = (tvm.tir.const(1, "int64"),)
        inp = relay.var("x", shape=data_shape)
        param_data = relay.var("p0", shape=data_shape)
        param_idx = relay.var("p1", shape=idx_shape, dtype="int64")
        idx_const = relay.const([0], dtype="int64")
        joined = relay.concatenate([param_data, param_data], axis=-1)
        result = relay.op.take(joined, indices=param_idx)

        # The fused region is marked primitive so it compiles as one kernel.
        prim = relay.Function([param_data, param_idx], result)
        prim = prim.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        call = relay.Call(prim, [inp, idx_const])
        return relay.Function([inp], call)

    orig = before()
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_gather_nd():
    """Test fusion case involving concat and gather_nd"""

    def before():
        data_shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        inp = relay.var("x", shape=data_shape)
        joined = relay.concatenate([inp, inp], axis=-1)
        result = relay.gather_nd(
            joined, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(result), result)

    def expected():
        data_shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        idx_shape = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
        inp = relay.var("x", shape=data_shape)
        param_data = relay.var("p0", shape=data_shape)
        param_idx = relay.var("p1", shape=idx_shape, dtype="int64")
        idx_const = relay.const([[0, 1], [1, 0]], dtype="int64")
        joined = relay.concatenate([param_data, param_data], axis=-1)
        result = relay.gather_nd(joined, indices=param_idx)

        # The fused region is marked primitive so it compiles as one kernel.
        prim = relay.Function([param_data, param_idx], result)
        prim = prim.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        call = relay.Call(prim, [inp, idx_const])
        return relay.Function([inp], call)

    orig = before()
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, 'llvm')
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -637,3 +749,6 @@ def expected():
test_immutable()
test_split()
test_fuse_max()
test_fuse_strided_slice()
test_fuse_take()
test_fuse_gather_nd()

0 comments on commit 318fea3

Please sign in to comment.