Skip to content

Commit

Permalink
Merge pull request #179 from bcaller/ifexp
Browse files Browse the repository at this point in the history
Better handling of IfExp (ternary)
  • Loading branch information
bcaller authored Oct 30, 2018
2 parents 5d7a94b + 2e4f8c9 commit 0932cc9
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 1 deletion.
9 changes: 9 additions & 0 deletions examples/example_inputs/ternary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
result = (
"abc"
if t.u == v.w else
"def"
if x else
y # This is the only RHS variable which taints result
if func(z if 1 + 1 == 2 else z) else
"ghi"
)
77 changes: 76 additions & 1 deletion pyt/core/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,80 @@ def visit_Return(self, node):
return self.visit_chain(node)


class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
class IfExpRewriter(ast.NodeTransformer):
"""Splits IfExp ternary expressions containing complex tests into multiple statements
Will change
a if b(c) else d
into
a if __if_exp_0 else d
with Assign nodes in assignments [__if_exp_0 = b(c)]
"""

def __init__(self, starting_index=0):
self._temporary_variable_index = starting_index
self.assignments = []
super().__init__()

def visit_IfExp(self, node):
if isinstance(node.test, (ast.Name, ast.Attribute)):
return self.generic_visit(node)
else:
temp_var_id = '__if_exp_{}'.format(self._temporary_variable_index)
self._temporary_variable_index += 1
assignment_of_test = ast.Assign(
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
value=self.visit(node.test),
)
ast.copy_location(assignment_of_test, node)
self.assignments.append(assignment_of_test)
transformed_if_exp = ast.IfExp(
test=ast.Name(id=temp_var_id, ctx=ast.Load()),
body=self.visit(node.body),
orelse=self.visit(node.orelse),
)
ast.copy_location(transformed_if_exp, node)
return transformed_if_exp

def visit_FunctionDef(self, node):
return node


class IfExpTransformer:
"""Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions."""

def visit_body(self, nodes):
new_nodes = []
count = 0
for node in nodes:
rewriter = IfExpRewriter(count)
possibly_transformed_node = rewriter.visit(node)
if rewriter.assignments:
new_nodes.extend(rewriter.assignments)
count += len(rewriter.assignments)
new_nodes.append(possibly_transformed_node)
return new_nodes

def visit_FunctionDef(self, node):
transformed = ast.FunctionDef(
name=node.name,
args=node.args,
body=self.visit_body(node.body),
decorator_list=node.decorator_list,
returns=node.returns
)
ast.copy_location(transformed, node)
return self.generic_visit(transformed)

def visit_Module(self, node):
transformed = ast.Module(self.visit_body(node.body))
ast.copy_location(transformed, node)
return self.generic_visit(transformed)


class PytTransformer(AsyncTransformer, IfExpTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
pass
9 changes: 9 additions & 0 deletions pyt/helper_visitors/label_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,12 @@ def visit_FormattedValue(self, node):
def visit_Starred(self, node):
self.result += '*'
self.visit(node.value)

def visit_IfExp(self, node):
self.result += '('
self.visit(node.test)
self.result += ') ? ('
self.visit(node.body)
self.result += ') : ('
self.visit(node.orelse)
self.result += ')'
5 changes: 5 additions & 0 deletions pyt/helper_visitors/right_hand_side_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def visit_Call(self, node):
for keyword in node.keywords:
self.visit(keyword)

def visit_IfExp(self, node):
# The test doesn't taint the assignment
self.visit(node.body)
self.visit(node.orelse)

@classmethod
def result_for_node(cls, node):
visitor = cls()
Expand Down
33 changes: 33 additions & 0 deletions tests/cfg/cfg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,39 @@ def test_if_not(self):
(_exit, _if)
])

def test_ternary_ifexp(self):
self.cfg_create_from_file('examples/example_inputs/ternary.py')

# entry = 0
tmp_if_1 = 1
# tmp_if_inner = 2
call = 3
# tmp_if_call = 4
actual_if_exp = 5
exit = 6

self.assert_length(self.cfg.nodes, expected_length=exit + 1)
self.assertInCfg([
(i + 1, i) for i in range(exit)
])

self.assertCountEqual(
self.cfg.nodes[actual_if_exp].right_hand_side_variables,
['y'],
"The variables in the test expressions shouldn't appear as RHS variables"
)

self.assertCountEqual(
self.cfg.nodes[tmp_if_1].right_hand_side_variables,
['t', 'v'],
)

self.assertIn(
'ret_func(',
self.cfg.nodes[call].label,
"Function calls inside the test expressions should still appear in the CFG",
)


class CFGWhileTest(CFGBaseTestCase):

Expand Down
16 changes: 16 additions & 0 deletions tests/core/transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,19 @@ def test_chained_function(self):

transformed = PytTransformer().visit(chained_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

def test_if_exp(self):
complex_if_exp_tree = ast.parse("\n".join([
"def a():",
" b = c if d.e(f) else g if h else i if j.k(l) else m",
]))

separated_tree = ast.parse("\n".join([
"def a():",
" __if_exp_0 = d.e(f)",
" __if_exp_1 = j.k(l)",
" b = c if __if_exp_0 else g if h else i if __if_exp_1 else m",
]))

transformed = PytTransformer().visit(complex_if_exp_tree)
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
4 changes: 4 additions & 0 deletions tests/helper_visitors/label_visitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ def test_joined_str_with_format_spec(self):
def test_starred(self):
label = self.perform_labeling_on_expression('[a, *b] = *c, d')
self.assertEqual(label.result, '[a, *b] = (*c, d)')

def test_if_exp(self):
label = self.perform_labeling_on_expression('a = b if c else d')
self.assertEqual(label.result, 'a = (c) ? (b) : (d)')

0 comments on commit 0932cc9

Please sign in to comment.