Skip to content

Commit

Permalink
[WIP] Fixing an Infinite Loop case in UnmatchedChecker. (apache#4881)
Browse files Browse the repository at this point in the history
* save

* save

* remove

* remove cerr
  • Loading branch information
MarisaKirisame authored and zhiics committed Mar 2, 2020
1 parent a6888fa commit e9d0b5e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
34 changes: 9 additions & 25 deletions src/relay/pass/match_exhaustion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,10 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
} else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) {
return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod);
} else {
return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
return {cand};
}
}

Expand Down Expand Up @@ -201,18 +203,9 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
bool subpattern =
clause_ctor->patterns[i].as<PatternConstructorNode>() ||
clause_ctor->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value.
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand.
values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
ctor_cand->patterns[i],
mod));
}
values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
ctor_cand->patterns[i],
mod));
}

// generate new candidates using a cartesian product.
Expand Down Expand Up @@ -243,18 +236,9 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
bool subpattern =
clause_tuple->patterns[i].as<PatternConstructorNode>() ||
clause_tuple->patterns[i].as<PatternTupleNode>();
// for non-ADT fields, we can only have a wildcard for the value
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
} else {
// otherwise, recursively expand
values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
tuple_cand->patterns[i],
mod));
}
values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
tuple_cand->patterns[i],
mod));
}

// generate new candidates using a cartesian product
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relay/test_pass_unmatched_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.analysis import unmatched_cases
import pytest

def test_empty_match_block():
# empty match block will not match anything, so it should return a wildcard pattern
Expand Down Expand Up @@ -273,3 +274,27 @@ def test_tuple_match():
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert len(unmatched_cases(x)) == 0


def test_inf_loop_case():
code = """
v0.0.4
type Arith[A] {
Zero,
Const(A),
Plus(Arith[A], Arith[A])
}
def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
match (%a) {
Plus(Zero, %r) => %r,
Plus(%l, Zero) => %l,
_ => %a
}
}
"""
relay.fromtext(code)
# fromtext parse the module, then checked it (which include strictness checking).

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit e9d0b5e

Please sign in to comment.