Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of IfExp (ternary) #179

Merged
merged 1 commit into from
Oct 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/example_inputs/ternary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
result = (
"abc"
if t.u == v.w else
"def"
if x else
y # This is the only RHS variable which taints result
if func(z if 1 + 1 == 2 else z) else
"ghi"
)
77 changes: 76 additions & 1 deletion pyt/core/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,80 @@ def visit_Return(self, node):
return self.visit_chain(node)


class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
class IfExpRewriter(ast.NodeTransformer):
"""Splits IfExp ternary expressions containing complex tests into multiple statements

Will change

a if b(c) else d

into

a if __if_exp_0 else d

with Assign nodes in assignments [__if_exp_0 = b(c)]
"""

def __init__(self, starting_index=0):
self._temporary_variable_index = starting_index
self.assignments = []
super().__init__()

def visit_IfExp(self, node):
if isinstance(node.test, (ast.Name, ast.Attribute)):
return self.generic_visit(node)
else:
temp_var_id = '__if_exp_{}'.format(self._temporary_variable_index)
self._temporary_variable_index += 1
assignment_of_test = ast.Assign(
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
value=self.visit(node.test),
)
ast.copy_location(assignment_of_test, node)
self.assignments.append(assignment_of_test)
transformed_if_exp = ast.IfExp(
test=ast.Name(id=temp_var_id, ctx=ast.Load()),
body=self.visit(node.body),
orelse=self.visit(node.orelse),
)
ast.copy_location(transformed_if_exp, node)
return transformed_if_exp

def visit_FunctionDef(self, node):
return node


class IfExpTransformer:
"""Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions."""

def visit_body(self, nodes):
new_nodes = []
count = 0
for node in nodes:
rewriter = IfExpRewriter(count)
possibly_transformed_node = rewriter.visit(node)
if rewriter.assignments:
new_nodes.extend(rewriter.assignments)
count += len(rewriter.assignments)
new_nodes.append(possibly_transformed_node)
return new_nodes

def visit_FunctionDef(self, node):
transformed = ast.FunctionDef(
name=node.name,
args=node.args,
body=self.visit_body(node.body),
decorator_list=node.decorator_list,
returns=node.returns
)
ast.copy_location(transformed, node)
return self.generic_visit(transformed)

def visit_Module(self, node):
transformed = ast.Module(self.visit_body(node.body))
ast.copy_location(transformed, node)
return self.generic_visit(transformed)


class PytTransformer(AsyncTransformer, IfExpTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
pass
9 changes: 9 additions & 0 deletions pyt/helper_visitors/label_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,12 @@ def visit_FormattedValue(self, node):
def visit_Starred(self, node):
self.result += '*'
self.visit(node.value)

def visit_IfExp(self, node):
self.result += '('
self.visit(node.test)
self.result += ') ? ('
self.visit(node.body)
self.result += ') : ('
self.visit(node.orelse)
self.result += ')'
5 changes: 5 additions & 0 deletions pyt/helper_visitors/right_hand_side_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def visit_Call(self, node):
for keyword in node.keywords:
self.visit(keyword)

def visit_IfExp(self, node):
# The test doesn't taint the assignment
self.visit(node.body)
self.visit(node.orelse)

@classmethod
def result_for_node(cls, node):
visitor = cls()
Expand Down
33 changes: 33 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,39 @@ def test_if_not(self):
(_exit, _if)
])

def test_ternary_ifexp(self):
self.cfg_create_from_file('examples/example_inputs/ternary.py')

# entry = 0
tmp_if_1 = 1
# tmp_if_inner = 2
call = 3
# tmp_if_call = 4
actual_if_exp = 5
exit = 6

self.assert_length(self.cfg.nodes, expected_length=exit + 1)
self.assertInCfg([
(i + 1, i) for i in range(exit)
])

self.assertCountEqual(
self.cfg.nodes[actual_if_exp].right_hand_side_variables,
['y'],
"The variables in the test expressions shouldn't appear as RHS variables"
)

self.assertCountEqual(
self.cfg.nodes[tmp_if_1].right_hand_side_variables,
['t', 'v'],
)

self.assertIn(
'ret_func(',
self.cfg.nodes[call].label,
"Function calls inside the test expressions should still appear in the CFG",
)


class CFGWhileTest(CFGBaseTestCase):

Expand Down
16 changes: 16 additions & 0 deletions tests/core/transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,19 @@ def test_chained_function(self):

transformed = PytTransformer().visit(chained_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

def test_if_exp(self):
complex_if_exp_tree = ast.parse("\n".join([
"def a():",
" b = c if d.e(f) else g if h else i if j.k(l) else m",
]))

separated_tree = ast.parse("\n".join([
"def a():",
" __if_exp_0 = d.e(f)",
" __if_exp_1 = j.k(l)",
" b = c if __if_exp_0 else g if h else i if __if_exp_1 else m",
]))

transformed = PytTransformer().visit(complex_if_exp_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
4 changes: 4 additions & 0 deletions tests/helper_visitors/label_visitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ def test_joined_str_with_format_spec(self):
def test_starred(self):
label = self.perform_labeling_on_expression('[a, *b] = *c, d')
self.assertEqual(label.result, '[a, *b] = (*c, d)')

def test_if_exp(self):
label = self.perform_labeling_on_expression('a = b if c else d')
self.assertEqual(label.result, 'a = (c) ? (b) : (d)')