diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 9969dd80f5ed..7668fa99e611 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -324,10 +324,18 @@ def _eval_compare(self, fields: Dict[str, Any]) -> Any: res : Any The evaluation result. """ - value = self._eval_expr(fields["left"]) - for op, rhs in zip(fields["ops"], fields["comparators"]): - value = _eval_op(op, values=[value, self._eval_expr(rhs)]) - return value + values = [self._eval_expr(fields["left"])] + values.extend([self._eval_expr(rhs) for rhs in fields["comparators"]]) + result = None + assert len(fields["ops"]) == len(values) - 1 + + for index, op in enumerate(fields["ops"]): + sub_result = _eval_op(op, values=[values[index], values[index + 1]]) + if result is None: + result = sub_result + else: + result = _eval_op(doc.And(), values=[result, sub_result]) + return result def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: """The doc AST unary operation node evaluating method. diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index d28e4680ae16..f1569be5b1f4 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -626,5 +626,25 @@ def expected(A: T.buffer((128, 128), "float32")): tvm.ir.assert_structural_equal(func, expected) +def test_sequence_compare(): + @T.prim_func(private=True) + def tir_func(A: T.Buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + if 0 < i < 128 and 0 < j < 128: + A[i, j] = 1 + else: + A[i, j] = 0 + + @T.prim_func(private=True) + def expected(A: T.buffer((128, 128), "float32")): + for i, j in T.grid(128, 128): + if (0 < i and i < 128) and (0 < j and j < 128): + A[i, j] = 1 + else: + A[i, j] = 0 + + tvm.ir.assert_structural_equal(tir_func, expected) + + if __name__ == "__main__": tvm.testing.main()