From 0c1bb250a7af9651cafb5b85ff3bd337841dbf8c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 26 Oct 2023 09:41:15 -0500 Subject: [PATCH] [Unity] Move VarBinding's and MatchCast's value to base class Prior to this commit, the `VarBinding` and `MatchCast` had equivalent `Expr value` fields. For use cases that need the bound value (e.g. collecting known values), this required downcasting to each subclass. This commit moves the `Expr value` to the base `Binding` class, removing the need for the downcasting. No impact to either C++ or Python usage is expected. On the C++ side, all use of `VarBindingNode::value` or `MatchCastNode::value` will now access `BindingNode::value`. On the Python side, all dynamic access of `binding_obj.value` will access the new field. --- include/tvm/relax/expr.h | 9 +++--- src/relax/backend/vm/codegen_vm.cc | 9 +----- src/relax/backend/vm/codegen_vm_tir.cc | 9 +----- src/relax/ir/block_builder.cc | 31 +++++++------------- src/relax/ir/expr_functor.cc | 3 +- src/relax/transform/canonicalize_bindings.cc | 15 ++++------ src/relax/transform/fuse_ops.cc | 8 +---- src/relax/utils.cc | 12 ++++---- 8 files changed, 29 insertions(+), 67 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 02d6f8d2767e..311c889bf24d 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -722,6 +722,10 @@ class BindingNode : public Object { public: /*! \brief The return variable to bound to. */ Var var; + + /*! \brief The binding value. */ + Expr value; + mutable Span span; static constexpr const char* _type_key = "relax.expr.Binding"; @@ -751,8 +755,6 @@ class Binding : public ObjectRef { */ class MatchCastNode : public BindingNode { public: - /*! \brief The input value to match cast. */ - Expr value; /*! \brief The struct info pattern to match to. */ StructInfo struct_info; @@ -796,9 +798,6 @@ class MatchCast : public Binding { class VarBindingNode : public BindingNode { public: - /*! \brief The binding value. */ - Expr value; - void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index caee0a0c13d6..8d07881c387d 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -122,14 +122,7 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { - Instruction::Arg value; - if (auto* var_binding = binding.as()) { - value = this->VisitExpr(var_binding->value); - } else if (auto* match_cast = binding.as()) { - value = this->VisitExpr(match_cast->value); - } else { - LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); - } + Instruction::Arg value = VisitExpr(binding->value); this->var_arg_map_.insert({binding->var, value}); } } diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 9ac65f6f6eb1..9197448016c4 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -203,14 +203,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Optional VisitExpr_(const SeqExprNode* op) final { for (auto block : op->blocks) { for (Binding binding : block->bindings) { - Optional value; - if (auto* var_binding = binding.as()) { - value = this->VisitExpr(var_binding->value); - } else if (auto* match_cast = binding.as()) { - value = this->VisitExpr(match_cast->value); - } else { - LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey(); - } + Optional value = VisitExpr(binding->value); this->var_map_.insert({binding->var, value}); } } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 5037161fcb90..921ca77e6c2c 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -226,29 +226,18 @@ class BlockBuilderImpl : public BlockBuilderNode { void EmitNormalized(Binding binding) final { BlockFrame* cur_frame = CurrentBlockFrame(); - if (const auto* var_binding = binding.as()) { - if (!cur_frame->is_dataflow) { - ICHECK(!var_binding->var.as()) - << "Cannot emit dataflow var in non-dataflow block"; - } - // normalized check - ICHECK(var_binding->var->struct_info_.defined()); - ICHECK(var_binding->value->struct_info_.defined()); - cur_frame->bindings.push_back(binding); - binding_table_[var_binding->var->vid] = var_binding->value; - } else if (const auto* match_cast = binding.as()) { - if (!cur_frame->is_dataflow) { - ICHECK(!match_cast->var.as()) - << "Cannot emit dataflow var in non-dataflow block"; - } - // normalized check - ICHECK(match_cast->var->struct_info_.defined()); - ICHECK(match_cast->value->struct_info_.defined()); + if (!cur_frame->is_dataflow) { + ICHECK(!binding->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; + } + // normalized check + ICHECK(binding->var->struct_info_.defined()); + ICHECK(binding->value->struct_info_.defined()); + cur_frame->bindings.push_back(binding); + if (binding.as()) { // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. - cur_frame->bindings.push_back(binding); - } else { - LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); + binding_table_[binding->var->vid] = binding->value; } } diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 14a704d729e3..6555b3ad36c9 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -519,11 +519,10 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { Array bindings; if (const auto* node = block.as()) { for (auto binding : node->bindings) { + Expr new_value = this->VisitExpr(binding->value); if (auto var_binding = binding.as()) { - Expr new_value = this->VisitExpr(var_binding->value); bindings.push_back(VarBinding(var_binding->var, new_value)); } else if (auto match_cast = binding.as()) { - Expr new_value = this->VisitExpr(match_cast->value); bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 246b38f6f83b..54f86f22ba4b 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -127,15 +127,10 @@ class CanonicalizePlanner : public ExprVisitor { void VisitBinding(const Binding& binding) override { bool has_same_struct_info = true; - Expr value; - if (auto ptr = binding.as()) { - value = ptr->value; - } else if (auto ptr = binding.as()) { + Expr value = binding->value; + if (auto ptr = binding.as()) { has_same_struct_info = StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value)); - value = ptr->value; - } else { - LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); } // Unwrap TupleGetItem, if the Tuple being accessed is known. @@ -229,7 +224,7 @@ class BindingCanonicalizer : public ExprMutator { for (int i = new_block->bindings.size() - 1; i >= 0; i--) { auto binding = new_block->bindings[i]; auto var = binding->var; - auto value = GetBoundValue(binding); + auto value = binding->value; if (var->IsInstance()) { auto df_var = Downcast(var); @@ -292,8 +287,8 @@ class BindingCanonicalizer : public ExprMutator { changed = true; continue; } else if (!binding->var->IsInstance() && - GetBoundValue(binding)->IsInstance() && - candidates.count(Downcast(GetBoundValue(binding)))) { + binding->value->IsInstance() && + candidates.count(Downcast(binding->value))) { changed = true; if (auto* match_binding = binding.as()) { auto new_binding = diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 5cabfc40ca9c..f19db72501fe 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -894,13 +894,7 @@ class OperatorFusor : public ExprMutator { } }; - if (const auto* var_binding = binding.as()) { - PostOrderVisit(var_binding->value, update_boundary); - } else { - const auto* match_cast = binding.as(); - ICHECK_NOTNULL(match_cast); - PostOrderVisit(match_cast->value, update_boundary); - } + PostOrderVisit(binding->value, update_boundary); } } diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 9e91e0759248..badabba9a84a 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -229,13 +229,13 @@ bool IsImpureCall(const Call& call) { } Expr GetBoundValue(const Binding& b) { - if (auto* var_binding = b.as()) { - return var_binding->value; - } else if (auto* match_binding = b.as()) { - return match_binding->value; - } else { - CHECK(false) << "Invalid binding (should never happen)"; + static bool first_usage = true; + if (first_usage) { + LOG(WARNING) << "Use of the GetBoundValue function is deprecated. " + << "The bound value can instead be accessed directly " + << "with 'binding->value'."; } + return b->value; } /*!