Skip to content

Commit

Permalink
[Unity] Move VarBinding's and MatchCast's value to base class
Browse files Browse the repository at this point in the history
Prior to this commit, the `VarBinding` and `MatchCast` had equivalent
`Expr value` fields.  For use cases that need the bound
value (e.g. collecting known values), this required downcasting to each
subclass.  This commit moves the `Expr value` to the base `Binding`
class, removing the need for the downcasting.

No impact to either C++ or Python usage is expected.  On the C++ side,
all use of `VarBindingNode::value` or `MatchCastNode::value` will now
access `BindingNode::value`.  On the Python side, all dynamic access
of `binding_obj.value` will access the new field.
  • Loading branch information
Lunderberg committed Oct 26, 2023
1 parent 4d19c8a commit 0c1bb25
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 67 deletions.
9 changes: 4 additions & 5 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,10 @@ class BindingNode : public Object {
public:
/*! \brief The return variable to bound to. */
Var var;

/*! \brief The binding value. */
Expr value;

mutable Span span;

static constexpr const char* _type_key = "relax.expr.Binding";
Expand Down Expand Up @@ -751,8 +755,6 @@ class Binding : public ObjectRef {
*/
class MatchCastNode : public BindingNode {
public:
/*! \brief The input value to match cast. */
Expr value;
/*! \brief The struct info pattern to match to. */
StructInfo struct_info;

Expand Down Expand Up @@ -796,9 +798,6 @@ class MatchCast : public Binding {

class VarBindingNode : public BindingNode {
public:
/*! \brief The binding value. */
Expr value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
Expand Down
9 changes: 1 addition & 8 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
Instruction::Arg VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Instruction::Arg value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
}
Instruction::Arg value = VisitExpr(binding->value);
this->var_arg_map_.insert({binding->var, value});
}
}
Expand Down
9 changes: 1 addition & 8 deletions src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
Optional<PrimExpr> VisitExpr_(const SeqExprNode* op) final {
for (auto block : op->blocks) {
for (Binding binding : block->bindings) {
Optional<PrimExpr> value;
if (auto* var_binding = binding.as<VarBindingNode>()) {
value = this->VisitExpr(var_binding->value);
} else if (auto* match_cast = binding.as<MatchCastNode>()) {
value = this->VisitExpr(match_cast->value);
} else {
LOG(FATAL) << "Unsupported binding " << binding->GetTypeKey();
}
Optional<PrimExpr> value = VisitExpr(binding->value);
this->var_map_.insert({binding->var, value});
}
}
Expand Down
31 changes: 10 additions & 21 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,29 +226,18 @@ class BlockBuilderImpl : public BlockBuilderNode {
void EmitNormalized(Binding binding) final {
BlockFrame* cur_frame = CurrentBlockFrame();

if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (!cur_frame->is_dataflow) {
ICHECK(!var_binding->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(var_binding->var->struct_info_.defined());
ICHECK(var_binding->value->struct_info_.defined());
cur_frame->bindings.push_back(binding);
binding_table_[var_binding->var->vid] = var_binding->value;
} else if (const auto* match_cast = binding.as<MatchCastNode>()) {
if (!cur_frame->is_dataflow) {
ICHECK(!match_cast->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(match_cast->var->struct_info_.defined());
ICHECK(match_cast->value->struct_info_.defined());
if (!cur_frame->is_dataflow) {
ICHECK(!binding->var.as<DataflowVarNode>())
<< "Cannot emit dataflow var in non-dataflow block";
}
// normalized check
ICHECK(binding->var->struct_info_.defined());
ICHECK(binding->value->struct_info_.defined());
cur_frame->bindings.push_back(binding);
if (binding.as<VarBindingNode>()) {
// NOTE match shape do not follow simple binding rule
// as a result should not appear in binding table.
cur_frame->bindings.push_back(binding);
} else {
LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey();
binding_table_[binding->var->vid] = binding->value;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,10 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) {
Array<Binding> bindings;
if (const auto* node = block.as<BindingBlockNode>()) {
for (auto binding : node->bindings) {
Expr new_value = this->VisitExpr(binding->value);
if (auto var_binding = binding.as<VarBindingNode>()) {
Expr new_value = this->VisitExpr(var_binding->value);
bindings.push_back(VarBinding(var_binding->var, new_value));
} else if (auto match_cast = binding.as<MatchCastNode>()) {
Expr new_value = this->VisitExpr(match_cast->value);
bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info));
} else {
LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
Expand Down
15 changes: 5 additions & 10 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,10 @@ class CanonicalizePlanner : public ExprVisitor {

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
Expr value = binding->value;
if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}

// Unwrap TupleGetItem, if the Tuple being accessed is known.
Expand Down Expand Up @@ -229,7 +224,7 @@ class BindingCanonicalizer : public ExprMutator {
for (int i = new_block->bindings.size() - 1; i >= 0; i--) {
auto binding = new_block->bindings[i];
auto var = binding->var;
auto value = GetBoundValue(binding);
auto value = binding->value;

if (var->IsInstance<DataflowVarNode>()) {
auto df_var = Downcast<DataflowVar>(var);
Expand Down Expand Up @@ -292,8 +287,8 @@ class BindingCanonicalizer : public ExprMutator {
changed = true;
continue;
} else if (!binding->var->IsInstance<DataflowVarNode>() &&
GetBoundValue(binding)->IsInstance<DataflowVarNode>() &&
candidates.count(Downcast<DataflowVar>(GetBoundValue(binding)))) {
binding->value->IsInstance<DataflowVarNode>() &&
candidates.count(Downcast<DataflowVar>(binding->value))) {
changed = true;
if (auto* match_binding = binding.as<MatchCastNode>()) {
auto new_binding =
Expand Down
8 changes: 1 addition & 7 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -894,13 +894,7 @@ class OperatorFusor : public ExprMutator {
}
};

if (const auto* var_binding = binding.as<VarBindingNode>()) {
PostOrderVisit(var_binding->value, update_boundary);
} else {
const auto* match_cast = binding.as<MatchCastNode>();
ICHECK_NOTNULL(match_cast);
PostOrderVisit(match_cast->value, update_boundary);
}
PostOrderVisit(binding->value, update_boundary);
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ bool IsImpureCall(const Call& call) {
}

Expr GetBoundValue(const Binding& b) {
if (auto* var_binding = b.as<VarBindingNode>()) {
return var_binding->value;
} else if (auto* match_binding = b.as<MatchCastNode>()) {
return match_binding->value;
} else {
CHECK(false) << "Invalid binding (should never happen)";
static bool first_usage = true;
if (first_usage) {
LOG(WARNING) << "Use of the GetBoundValue function is deprecated. "
<< "The bound value can instead be accessed directly "
<< "with 'binding->value'.";
}
return b->value;
}

/*!
Expand Down

0 comments on commit 0c1bb25

Please sign in to comment.