Skip to content

Commit

Permalink
Merge pull request #186 from adrianbn/133_while_test_with_func
Browse files Browse the repository at this point in the history
Visit functions in while test (#133)
  • Loading branch information
bcaller authored Nov 21, 2018
2 parents ce56a20 + 9cb0b56 commit 4b495ad
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 6 deletions.
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return True

while foo():
print(x)
x += 1
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator_lhs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return 6

while foo() > x:
print(x)
x += 1
6 changes: 6 additions & 0 deletions examples/example_inputs/while_func_comparator_rhs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def foo():
return 6

while x < foo():
print(x)
x += 1
33 changes: 27 additions & 6 deletions pyt/cfg/stmt_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,23 +555,44 @@ def visit_For(self, node):
path=self.filenames[-1]
))

if isinstance(node.iter, ast.Call) and get_call_names_as_string(node.iter.func) in self.function_names:
last_node = self.visit(node.iter)
last_node.connect(for_node)
self.process_loop_funcs(node.iter, for_node)

return self.loop_node_skeleton(for_node, node)

def process_loop_funcs(self, comp_n, loop_node):
"""
If the loop test node contains function calls, it connects the loop node to the nodes of
those function calls.
:param comp_n: The test node of a loop that may contain functions.
:param loop_node: The loop node itself to connect to the new function nodes if any
:return: None
"""
if isinstance(comp_n, ast.Call) and get_call_names_as_string(comp_n.func) in self.function_names:
last_node = self.visit(comp_n)
last_node.connect(loop_node)

def visit_While(self, node):
label_visitor = LabelVisitor()
label_visitor.visit(node.test)
test = node.test # the test condition of the while loop
label_visitor.visit(test)

test = self.append_node(Node(
while_node = self.append_node(Node(
'while ' + label_visitor.result + ':',
node,
path=self.filenames[-1]
))

return self.loop_node_skeleton(test, node)
if isinstance(test, ast.Compare):
# quirk. See https://greentreesnakes.readthedocs.io/en/latest/nodes.html#Compare
self.process_loop_funcs(test.left, while_node)

for comp in test.comparators:
self.process_loop_funcs(comp, while_node)
else: # while foo():
self.process_loop_funcs(test, while_node)

return self.loop_node_skeleton(while_node, node)

def add_blackbox_or_builtin_call(self, node, blackbox): # noqa: C901
"""Processes a blackbox or builtin function when it is called.
Expand Down
90 changes: 90 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,96 @@ def test_while_line_numbers(self):
self.assertLineNumber(else_body_2, 6)
self.assertLineNumber(next_stmt, 7)

def test_while_func_comparator(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while foo():')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])

def test_while_func_comparator_rhs(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_rhs.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while x < foo():')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])

def test_while_func_comparator_lhs(self):
self.cfg_create_from_file('examples/example_inputs/while_func_comparator_lhs.py')

self.assert_length(self.cfg.nodes, expected_length=9)

entry = 0
test = 1
entry_foo = 2
ret_foo = 3
exit_foo = 4
call_foo = 5
_print = 6
body_1 = 7
_exit = 8

self.assertEqual(self.cfg.nodes[test].label, 'while foo() > x:')

self.assertInCfg([
(test, entry),
(entry_foo, test),
(_print, test),
(_exit, test),
(body_1, _print),
(test, body_1),
(test, call_foo),
(ret_foo, entry_foo),
(exit_foo, ret_foo),
(call_foo, exit_foo)
])


class CFGAssignmentMultiTest(CFGBaseTestCase):
def test_assignment_multi_target(self):
Expand Down

0 comments on commit 4b495ad

Please sign in to comment.