Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Move VarBinding's and MatchCast's value to base class #15992

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ class BindingNode : public Object {
public:
/*! \brief The return variable to bound to. */
Var var;

/*! \brief The binding value. */
Expr value;

mutable Span span;

static constexpr const char* _type_key = "relax.expr.Binding";
Expand Down Expand Up @@ -751,8 +755,6 @@ class Binding : public ObjectRef {
*/
class MatchCastNode : public BindingNode {
public:
/*! \brief The input value to match cast. */
Expr value;
/*! \brief The struct info pattern to match to. */
StructInfo struct_info;

Expand Down Expand Up @@ -796,9 +798,6 @@ class MatchCast : public Binding {

class VarBindingNode : public BindingNode {
public:
/*! \brief The binding value. */
Expr value;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We intentionally keep the value in the subclass, because in match clas there is extra set of struct info to be considered(that can populates new values), and it is important to keep reminding the users of this.

The extra dispatch is worth overall for this explicitness

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, and closing this PR.


void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
Expand Down
9 changes: 1 addition & 8 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Instruction::Arg value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
}
Instruction::Arg value = VisitExpr(binding->value);
this->var_arg_map_.insert({binding->var, value});
}
}
Expand Down
9 changes: 1 addition & 8 deletions src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Optional<PrimExpr> value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
}
Optional<PrimExpr> value = VisitExpr(binding->value);
this->var_map_.insert({binding->var, value});
}
}
Expand Down
31 changes: 10 additions & 21 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,29 +226,18 @@ class BlockBuilderImpl : public BlockBuilderNode {
void EmitNormalized(Binding binding) final {
BlockFrame* cur_frame = CurrentBlockFrame();

if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (!cur_frame->is_dataflow) {
ICHECK(!var_binding->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(var_binding->var->struct_info_.defined());
ICHECK(var_binding->value->struct_info_.defined());
cur_frame->bindings.push_back(binding);
binding_table_[var_binding->var->vid] = var_binding->value;
} else if (const auto* match_cast = binding.as<MatchCastNode>()) {
if (!cur_frame->is_dataflow) {
ICHECK(!match_cast->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(match_cast->var->struct_info_.defined());
ICHECK(match_cast->value->struct_info_.defined());
if (!cur_frame->is_dataflow) {
ICHECK(!binding->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(binding->var->struct_info_.defined());
ICHECK(binding->value->struct_info_.defined());
cur_frame->bindings.push_back(binding);
if (binding.as<VarBindingNode>()) {
// NOTE match shape do not follow simple binding rule
// as a result should not appear in binding table.
cur_frame->bindings.push_back(binding);
} else {
LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey();
binding_table_[binding->var->vid] = binding->value;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,10 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) {
Array<Binding> bindings;
if (const auto* node = block.as<BindingBlockNode>()) {
for (auto binding : node->bindings) {
Expr new_value = this->VisitExpr(binding->value);
if (auto var_binding = binding.as<VarBindingNode>()) {
Expr new_value = this->VisitExpr(var_binding->value);
bindings.push_back(VarBinding(var_binding->var, new_value));
} else if (auto match_cast = binding.as<MatchCastNode>()) {
Expr new_value = this->VisitExpr(match_cast->value);
bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info));
} else {
LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
Expand Down
15 changes: 5 additions & 10 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,10 @@ class CanonicalizePlanner : public ExprVisitor {

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
Expr value = binding->value;
if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}

// Unwrap TupleGetItem, if the Tuple being accessed is known.
Expand Down Expand Up @@ -229,7 +224,7 @@ class BindingCanonicalizer : public ExprMutator {
for (int i = new_block->bindings.size() - 1; i >= 0; i--) {
auto binding = new_block->bindings[i];
auto var = binding->var;
auto value = GetBoundValue(binding);
auto value = binding->value;

if (var->IsInstance<DataflowVarNode>()) {
auto df_var = Downcast<DataflowVar>(var);
Expand Down Expand Up @@ -292,8 +287,8 @@ class BindingCanonicalizer : public ExprMutator {
changed = true;
continue;
} else if (!binding->var->IsInstance<DataflowVarNode>() &&
GetBoundValue(binding)->IsInstance<DataflowVarNode>() &&
candidates.count(Downcast<DataflowVar>(GetBoundValue(binding)))) {
binding->value->IsInstance<DataflowVarNode>() &&
candidates.count(Downcast<DataflowVar>(binding->value))) {
changed = true;
if (auto* match_binding = binding.as<MatchCastNode>()) {
auto new_binding =
Expand Down
8 changes: 1 addition & 7 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -894,13 +894,7 @@ class OperatorFusor : public ExprMutator {
}
};

if (const auto* var_binding = binding.as<VarBindingNode>()) {
PostOrderVisit(var_binding->value, update_boundary);
} else {
const auto* match_cast = binding.as<MatchCastNode>();
ICHECK_NOTNULL(match_cast);
PostOrderVisit(match_cast->value, update_boundary);
}
PostOrderVisit(binding->value, update_boundary);
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ bool IsImpureCall(const Call& call) {
}

Expr GetBoundValue(const Binding& b) {
if (auto* var_binding = b.as<VarBindingNode>()) {
return var_binding->value;
} else if (auto* match_binding = b.as<MatchCastNode>()) {
return match_binding->value;
} else {
CHECK(false) << "Invalid binding (should never happen)";
static bool first_usage = true;
if (first_usage) {
LOG(WARNING) << "Use of the GetBoundValue function is deprecated. "
<< "The bound value can instead be accessed directly "
<< "with 'binding->value'.";
}
return b->value;
}

/*!
Expand Down
Loading