diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 885c47ef5845..14be6b751354 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -168,8 +168,10 @@ Array ExpandWildcards(const Pattern& clause_pat, const IRModule& mod) { if (auto clause_ctor = clause_pat.as()) { return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); + } else if (auto clause_tup = clause_pat.as()) { + return ExpandWildcardsTuple(GetRef(clause_tup), cand, mod); } else { - return ExpandWildcardsTuple(Downcast(clause_pat), cand, mod); + return {cand}; } } @@ -201,18 +203,9 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { - bool subpattern = - clause_ctor->patterns[i].as() || - clause_ctor->patterns[i].as(); - // for non-ADT fields, we can only have a wildcard for the value. - if (!subpattern) { - values_by_field.push_back({PatternWildcardNode::make()}); - } else { - // otherwise, recursively expand. - values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], - ctor_cand->patterns[i], - mod)); - } + values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], + ctor_cand->patterns[i], + mod)); } // generate new candidates using a cartesian product. @@ -243,18 +236,9 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { - bool subpattern = - clause_tuple->patterns[i].as() || - clause_tuple->patterns[i].as(); - // for non-ADT fields, we can only have a wildcard for the value - if (!subpattern) { - values_by_field.push_back({PatternWildcardNode::make()}); - } else { - // otherwise, recursively expand - values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], - tuple_cand->patterns[i], - mod)); - } + values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], + tuple_cand->patterns[i], + mod)); } // generate new candidates using a cartesian product diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 615d4e092291..1ac99a69a249 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -19,6 +19,7 @@ from tvm import relay from tvm.relay.prelude import Prelude from tvm.relay.analysis import unmatched_cases +import pytest def test_empty_match_block(): # empty match block will not match anything, so it should return a wildcard pattern @@ -273,3 +274,27 @@ def test_tuple_match(): clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) assert len(unmatched_cases(x)) == 0 + + +def test_inf_loop_case(): + code = """ +v0.0.4 +type Arith[A] { + Zero, + Const(A), + Plus(Arith[A], Arith[A]) +} + +def @shallow_opt[A](%a: Arith[A]) -> Arith[A] { + match (%a) { + Plus(Zero, %r) => %r, + Plus(%l, Zero) => %l, + _ => %a + } +} +""" + relay.fromtext(code) + # fromtext parse the module, then checked it (which include strictness checking). + +if __name__ == "__main__": + pytest.main([__file__])