diff --git a/examples/example_inputs/while_func_comparator.py b/examples/example_inputs/while_func_comparator.py new file mode 100644 index 00000000..8c775f72 --- /dev/null +++ b/examples/example_inputs/while_func_comparator.py @@ -0,0 +1,6 @@ +def foo(): + return True + +while foo(): + print(x) + x += 1 diff --git a/examples/example_inputs/while_func_comparator_lhs.py b/examples/example_inputs/while_func_comparator_lhs.py new file mode 100644 index 00000000..1904e8e7 --- /dev/null +++ b/examples/example_inputs/while_func_comparator_lhs.py @@ -0,0 +1,6 @@ +def foo(): + return 6 + +while foo() > x: + print(x) + x += 1 diff --git a/examples/example_inputs/while_func_comparator_rhs.py b/examples/example_inputs/while_func_comparator_rhs.py new file mode 100644 index 00000000..6aafc2b6 --- /dev/null +++ b/examples/example_inputs/while_func_comparator_rhs.py @@ -0,0 +1,6 @@ +def foo(): + return 6 + +while x < foo(): + print(x) + x += 1 diff --git a/pyt/cfg/stmt_visitor.py b/pyt/cfg/stmt_visitor.py index 95913211..3b9d5f48 100644 --- a/pyt/cfg/stmt_visitor.py +++ b/pyt/cfg/stmt_visitor.py @@ -555,23 +555,44 @@ def visit_For(self, node): path=self.filenames[-1] )) - if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names: - last_node = self.visit(node.iter) - last_node.connect(for_node) + self.process_loop_funcs(node.iter, for_node) return self.loop_node_skeleton(for_node, node) + def process_loop_funcs(self, comp_n, loop_node): + """ + If the loop test node contains function calls, it connects the loop node to the nodes of + those function calls. + + :param comp_n: The test node of a loop that may contain functions. + :param loop_node: The loop node itself to connect to the new function nodes if any + :return: None + """ + if isinstance(comp_n, ast.Call) and get_call_names_as_string(comp_n.func) in self.function_names: + last_node = self.visit(comp_n) + last_node.connect(loop_node) + def visit_While(self, node): label_visitor = LabelVisitor() - label_visitor.visit(node.test) + test = node.test # the test condition of the while loop + label_visitor.visit(test) - test = self.append_node(Node( + while_node = self.append_node(Node( 'while ' + label_visitor.result + ':', node, path=self.filenames[-1] )) - return self.loop_node_skeleton(test, node) + if isinstance(test, ast.Compare): + # quirk. See https://greentreesnakes.readthedocs.io/en/latest/nodes.html#Compare + self.process_loop_funcs(test.left, while_node) + + for comp in test.comparators: + self.process_loop_funcs(comp, while_node) + else: # while foo(): + self.process_loop_funcs(test, while_node) + + return self.loop_node_skeleton(while_node, node) def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901 """Processes a blackbox or builtin function when it is called. diff --git a/tests/cfg/cfg_test.py b/tests/cfg/cfg_test.py index a4c24ba5..3af37942 100644 --- a/tests/cfg/cfg_test.py +++ b/tests/cfg/cfg_test.py @@ -684,6 +684,96 @@ def test_while_line_numbers(self): self.assertLineNumber(else_body_2, 6) self.assertLineNumber(next_stmt, 7) + def test_while_func_comparator(self): + self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py') + + self.assert_length(self.cfg.nodes, expected_length=9) + + entry = 0 + test = 1 + entry_foo = 2 + ret_foo = 3 + exit_foo = 4 + call_foo = 5 + _print = 6 + body_1 = 7 + _exit = 8 + + self.assertEqual(self.cfg.nodes[test].label, 'while foo():') + + self.assertInCfg([ + (test, entry), + (entry_foo, test), + (_print, test), + (_exit, test), + (body_1, _print), + (test, body_1), + (test, call_foo), + (ret_foo, entry_foo), + (exit_foo, ret_foo), + (call_foo, exit_foo) + ]) + + def test_while_func_comparator_rhs(self): + self.cfg_create_from_file('examples/example_inputs/while_func_comparator_rhs.py') + + self.assert_length(self.cfg.nodes, expected_length=9) + + entry = 0 + test = 1 + entry_foo = 2 + ret_foo = 3 + exit_foo = 4 + call_foo = 5 + _print = 6 + body_1 = 7 + _exit = 8 + + self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():') + + self.assertInCfg([ + (test, entry), + (entry_foo, test), + (_print, test), + (_exit, test), + (body_1, _print), + (test, body_1), + (test, call_foo), + (ret_foo, entry_foo), + (exit_foo, ret_foo), + (call_foo, exit_foo) + ]) + + def test_while_func_comparator_lhs(self): + self.cfg_create_from_file('examples/example_inputs/while_func_comparator_lhs.py') + + self.assert_length(self.cfg.nodes, expected_length=9) + + entry = 0 + test = 1 + entry_foo = 2 + ret_foo = 3 + exit_foo = 4 + call_foo = 5 + _print = 6 + body_1 = 7 + _exit = 8 + + self.assertEqual(self.cfg.nodes[test].label, 'while foo() > x:') + + self.assertInCfg([ + (test, entry), + (entry_foo, test), + (_print, test), + (_exit, test), + (body_1, _print), + (test, body_1), + (test, call_foo), + (ret_foo, entry_foo), + (exit_foo, ret_foo), + (call_foo, exit_foo) + ]) + class CFGAssignmentMultiTest(CFGBaseTestCase): def test_assignment_multi_target(self):