Skip to content

Commit

Permalink
[RELAY][Fix] i64 indices (apache#5235)
Browse files Browse the repository at this point in the history
* fix

* resolve comments
  • Loading branch information
meta-project-ci authored and Trevor Morris committed Sep 2, 2020
1 parent 14e2ba6 commit f226005
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/te/schedule/operation_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class OperationInliner final : public StmtExprMutator {
} else {
Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->indices[i]);
// cast indices to the type of the original indexing variable
vmap.Set(args_[i], cast(args_[i].dtype(), op->indices[i]));
}
expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
}
Expand Down
75 changes: 75 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,79 @@ def expected():
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, after)


def test_fuse_take():
"""Test fusion case involving concat and take"""

def before():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
concat = relay.concatenate([x,x], axis=-1)
out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
return relay.Function(relay.analysis.free_vars(out), out)

def expected():
shape1 = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
shape2 = (tvm.tir.const(1, "int64"),)
x = relay.var("x", shape=shape1)
p0 = relay.var("p0", shape=shape1)
p1 = relay.var("p1", shape=shape2,
dtype="int64")
c = relay.const([0], dtype="int64")
concat = relay.concatenate([p0,p0], axis=-1)
out = relay.op.take(concat, indices=p1)

f0 = relay.Function([p0, p1], out)
f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

y = relay.Call(f0, [x, c])
return relay.Function([x], y)

orig = before()
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)


def test_fuse_gather_nd():
"""Test fusion case involving concat and gather_nd"""

def before():
shape = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
x = relay.var("x", shape=shape)
concat = relay.concatenate([x,x], axis=-1)
out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
return relay.Function(relay.analysis.free_vars(out), out)

def expected():
shape1 = (tvm.tir.const(10, "int64"),
tvm.tir.const(1, "int64"))
shape2 = (tvm.tir.const(2, "int64"),
tvm.tir.const(2, "int64"))
x = relay.var("x", shape=shape1)
p0 = relay.var("p0", shape=shape1)
p1 = relay.var("p1", shape=shape2, dtype="int64")
c = relay.const([[0,1],[1,0]], dtype="int64")
concat = relay.concatenate([p0,p0], axis=-1)
out = relay.gather_nd(concat, indices=p1)

f0 = relay.Function([p0, p1], out)
f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

y = relay.Call(f0, [x, c])
return relay.Function([x], y)

orig = before()
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(m["main"], after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -637,3 +710,5 @@ def expected():
test_immutable()
test_split()
test_fuse_max()
test_fuse_take()
test_fuse_gather_nd()

0 comments on commit f226005

Please sign in to comment.