From 0ea564ea4dd4384cfdb0516c2e0bc919650867a2 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 30 Jul 2024 10:22:01 -0500
Subject: [PATCH] [Relax] Allow dynamic shape argument to R.reshape

Prior to this commit, the `shape` argument to `R.reshape` was required
either to be an in-line `relax::ShapeExpr`, or to be a variable that had
been bound to a `relax::ShapeExpr` within the current function.  As a
result, shapes that were provided as function arguments, or that were
produced by another operation (e.g. `R.tensor_to_shape`), would
unnecessarily trigger an error.

This commit updates the `VMBuiltinLower` pass to instead check that the
argument has `relax::ShapeStructInfo`.

Closes https://github.com/apache/tvm/issues/17217
---
 src/relax/backend/vm/lower_runtime_builtin.cc | 36 +++++-----
 tests/python/relax/test_vm_builtin_lower.py   | 65 +++++++++++++++++++
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
index a3867ae92448..4757561b549b 100644
--- a/src/relax/backend/vm/lower_runtime_builtin.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -49,6 +49,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
       return Reshape(call);
     } else if (call->op == shape_of_op_) {
       return ShapeOf(call);
+    } else if (call->op == tensor_to_shape_op_) {
+      return TensorToShape(call);
     } else if (call->op == to_vdevice_op_) {
       return ToDevice(call);
     } else if (call->op == make_closure_op_) {
@@ -112,22 +114,15 @@
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->struct_info_.defined());
     auto arg = call_node->args[1];
-    CHECK(arg->IsInstance<ShapeExprNode>() || arg->IsInstance<VarNode>())
-        << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr or VarNode bound "
-           "to a ShapeExpr";
-
-    if (arg->IsInstance<ShapeExprNode>()) {
-      return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
-    } else {
-      // Handling the case when arg is VarNode
-      Optional<Expr> _bound_val = LookupBinding(Downcast<Var>(arg));
-      ICHECK(_bound_val.defined());
-      Expr bound_val = _bound_val.value();
-      CHECK(bound_val->IsInstance<ShapeExprNode>())
-          << "VMBuiltinLower expects bound value to be a ShapeExpr";
-      return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(),
-                  {GetStructInfo(call_node)});
-    }
+
+    CHECK(arg->struct_info_->IsInstance<ShapeStructInfoNode>())
+        << "TypeError: "
+        << "VMBuiltinLower expects the shape arg of R.reshape "
+        << "to be a ShapeExpr or VarNode bound to a ShapeExpr.  "
" + << "However, in expression " << call_node << ", the shape argument " << arg + << " has struct info " << arg->struct_info_; + + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } Expr ShapeOf(const Call& call_node) { @@ -136,6 +131,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr TensorToShape(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + + return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr ICHECK(call_node->args.size() == 1); @@ -194,6 +196,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); @@ -211,6 +214,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; + const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"}; const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index 984f9f958ca2..daa59793cc47 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -82,5 +82,70 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: relax.transform.LowerRuntimeBuiltin()(Before) +def test_vm_reshape_may_be_var(): + """R.reshape does not require an in-line R.shape""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32"), shape: R.Shape): + R.func_attr({"relax.force_pure": True}) + reshape = R.reshape(A, shape) + return reshape + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32"), shape: R.Shape): + R.func_attr({"relax.force_pure": True}) + reshape = R.call_packed( + "vm.builtin.reshape", + A, + shape, + sinfo_args=R.Tensor(shape, dtype="float32"), + ) + return reshape + + After = relax.transform.VMBuiltinLower()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_vm_reshape_using_tensor_to_shape(): + """Shape argument of R.reshape may come from tensor_to_shape""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): + R.func_attr({"relax.force_pure": True}) + shape = R.tensor_to_shape(shape_tensor) + reshape = R.reshape(A, shape) + return reshape + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): + R.func_attr({"relax.force_pure": True}) + + shape = R.call_packed( + "vm.builtin.tensor_to_shape", + shape_tensor, + sinfo_args=R.Shape(ndim=2), + ) + reshape = R.call_packed( + 
"vm.builtin.reshape", + A, + shape, + sinfo_args=R.Tensor(shape, dtype="float32"), + ) + return reshape + + After = relax.transform.VMBuiltinLower()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()