36 changes: 20 additions & 16 deletions src/relax/backend/vm/lower_runtime_builtin.cc
@@ -49,6 +49,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
       return Reshape(call);
     } else if (call->op == shape_of_op_) {
       return ShapeOf(call);
+    } else if (call->op == tensor_to_shape_op_) {
+      return TensorToShape(call);
     } else if (call->op == to_vdevice_op_) {
       return ToDevice(call);
     } else if (call->op == make_closure_op_) {
@@ -112,22 +114,15 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->struct_info_.defined());
     auto arg = call_node->args[1];
-    CHECK(arg->IsInstance<ShapeExprNode>() || arg->IsInstance<VarNode>())
-        << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr or VarNode bound "
-           "to a ShapeExpr";
-
-    if (arg->IsInstance<ShapeExprNode>()) {
-      return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
-    } else {
-      // Handling the case when arg is VarNode
-      Optional<Expr> _bound_val = LookupBinding(Downcast<Var>(arg));
-      ICHECK(_bound_val.defined());
-      Expr bound_val = _bound_val.value();
-      CHECK(bound_val->IsInstance<ShapeExprNode>())
-          << "VMBuiltinLower expects bound value to be a ShapeExpr";
-      return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(),
-                  {GetStructInfo(call_node)});
-    }
+
+    CHECK(arg->struct_info_->IsInstance<ShapeStructInfoNode>())
+        << "TypeError: "
+        << "VMBuiltinLower expects the shape arg of R.reshape "
+        << "to be a ShapeExpr or VarNode bound to a ShapeExpr. "
+        << "However, in expression " << call_node << ", the shape argument " << arg
+        << " has struct info " << arg->struct_info_;
+
+    return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
   }
 
   Expr ShapeOf(const Call& call_node) {
@@ -136,6 +131,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
     return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
   }
 
+  Expr TensorToShape(const Call& call_node) {
+    ICHECK(call_node->args.size() == 1);
+    ICHECK(call_node->struct_info_.defined());
+
+    return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+  }
+
   Expr ToDevice(const Call& call_node) {
     // TODO(yongwww): replace ToVDeviceAttrs with related Expr
     ICHECK(call_node->args.size() == 1);
Expand Down Expand Up @@ -194,6 +196,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& reshape_op_ = Op::Get("relax.reshape");
   const Op& shape_of_op_ = Op::Get("relax.shape_of");
+  const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
   const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -211,6 +214,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
   const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
   const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
+  const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
   const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
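With these changes, Reshape keys off struct info rather than the expression's kind: any shape argument whose struct info is ShapeStructInfo, such as the Var produced by R.tensor_to_shape, is forwarded to vm.builtin.reshape unchanged, and tensor_to_shape itself lowers to vm.builtin.tensor_to_shape. A minimal sketch of running just this pass on a hypothetical example module (assuming a standard TVM build; the tests below pin down the exact lowered form):

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")):
        R.func_attr({"relax.force_pure": True})
        shape = R.tensor_to_shape(shape_tensor)  # -> vm.builtin.tensor_to_shape
        reshape = R.reshape(A, shape)            # -> vm.builtin.reshape
        return reshape

lowered = relax.transform.LowerRuntimeBuiltin()(Module)
lowered.show()  # both calls now appear as R.call_packed(...)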
65 changes: 65 additions & 0 deletions tests/python/relax/test_vm_builtin_lower.py
@@ -82,5 +82,70 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
         relax.transform.LowerRuntimeBuiltin()(Before)
 
 
+def test_vm_reshape_may_be_var():
+    """R.reshape does not require an in-line R.shape"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32"), shape: R.Shape):
+            R.func_attr({"relax.force_pure": True})
+            reshape = R.reshape(A, shape)
+            return reshape
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor([16], "float32"), shape: R.Shape):
+            R.func_attr({"relax.force_pure": True})
+            reshape = R.call_packed(
+                "vm.builtin.reshape",
+                A,
+                shape,
+                sinfo_args=R.Tensor(shape, dtype="float32"),
+            )
+            return reshape
+
+    After = relax.transform.VMBuiltinLower()(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_vm_reshape_using_tensor_to_shape():
+    """Shape argument of R.reshape may come from tensor_to_shape"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")):
+            R.func_attr({"relax.force_pure": True})
+            shape = R.tensor_to_shape(shape_tensor)
+            reshape = R.reshape(A, shape)
+            return reshape
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")):
+            R.func_attr({"relax.force_pure": True})
+
+            shape = R.call_packed(
+                "vm.builtin.tensor_to_shape",
+                shape_tensor,
+                sinfo_args=R.Shape(ndim=2),
+            )
+            reshape = R.call_packed(
+                "vm.builtin.reshape",
+                A,
+                shape,
+                sinfo_args=R.Tensor(shape, dtype="float32"),
+            )
+            return reshape
+
+    After = relax.transform.VMBuiltinLower()(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
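As a further hedged sketch (not part of this PR's test suite; assumes a TVM build with the LLVM backend enabled), the same module can be compiled and executed on the Relax VM, with vm.builtin.tensor_to_shape materializing the shape from the tensor's contents at runtime:

import numpy as np
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")):
        R.func_attr({"relax.force_pure": True})
        shape = R.tensor_to_shape(shape_tensor)
        reshape = R.reshape(A, shape)
        return reshape

# relax.build applies the full lowering pipeline, including LowerRuntimeBuiltin.
ex = relax.build(Module, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](
    tvm.nd.array(np.arange(16, dtype="float32")),
    tvm.nd.array(np.array([4, 4], dtype="int64")),
)
assert out.numpy().shape == (4, 4)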