Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Resolve breakages in unit tests
Browse files Browse the repository at this point in the history
All breakage was the result of callers relying on ill-formed Relax
maintaining that specific type form of ill-formed-ness.
Lunderberg committed Apr 9, 2024
1 parent 15bc0d2 commit fbcf057
Showing 3 changed files with 36 additions and 9 deletions.
29 changes: 23 additions & 6 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
@@ -59,13 +59,30 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
return VisitDFPattern(pattern, expr);
}

static Expr TryGetValOfVar(const Expr& expr, const Map<Var, Expr>& var2val) {
if (var2val.empty()) return expr;
static Expr TryGetValOfVar(Expr expr, const Map<Var, Expr>& var2val) {
auto unwrap = [&](Expr expr) -> Optional<Expr> {
// Unwrap variables into the value to which they are bound.
if (var2val.size()) {
if (const VarNode* var = expr.as<VarNode>()) {
if (auto may = var2val.Get(GetRef<Var>(var))) {
return may.value();
}
}
}

// Unwrap SeqExpr with no bindings. These can occur due to Relax
// IR constraints for the bodies of Function and If nodes.
if (auto seq = expr.as<SeqExprNode>()) {
if (seq->blocks.empty()) {
return seq->body;
}
}

return NullOpt;
};

// if not match, try to match value of var if expr is a var.
if (const VarNode* var = expr.as<VarNode>()) {
auto may = var2val.Get(GetRef<Var>(var));
if (may.defined()) return may.value();
while (auto unwrapped = unwrap(expr)) {
expr = unwrapped.value();
}

return expr;
14 changes: 12 additions & 2 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
@@ -1257,9 +1257,19 @@ class CompositeFunctionAnnotator : public ExprMutator {
params.push_back(new_v);
}

// We cannot delegate to `ExprMutator::VisitExpr_(const FunctionNode*)` at this point, as it
// would recursively visit the Call node. However, we are still required to generate
// well-formed Relax IR. As a result, we need to build the SeqExpr ourselves.
Var local_func_var("local_func", GetStructInfo(f_inner));
Var output_var("output", f_inner->ret_struct_info);
SeqExpr new_body({BindingBlock({
VarBinding(local_func_var, f_inner),
VarBinding(output_var, Call(local_func_var, params)),
})},
output_var);

// pure if the inner func is pure (no need to force purity if it's forced for the inner func)
return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info,
f_inner->is_pure);
return Function(param_vars, new_body, func_node->ret_struct_info, f_inner->is_pure);
}

private:
2 changes: 1 addition & 1 deletion tests/python/relax/test_expr_functor.py
Original file line number Diff line number Diff line change
@@ -439,7 +439,7 @@ def test_if():
if_node = relax.If(x, x, x)
basic_check(
if_node,
"\n".join(["If", "\tVar", "\tVar", "\tVar"]),
"\n".join(["If", "\tVar", "\tSeqExpr", "\t\tVar", "\tSeqExpr", "\t\tVar"]),
"\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"]),
)

0 comments on commit fbcf057

Please sign in to comment.