Skip to content

Commit

Permalink
hoist global / nonlocal declarations outside of instrumentation branch
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Feb 29, 2024
1 parent c392ac8 commit 78c19ec
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions pyccolo/stmt_inserter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DefaultDict,
Dict,
List,
Tuple,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -62,20 +63,17 @@ def _get_parsed_append_stmt(
return ret


class StripGlobalAndNonlocalDeclarations(ast.NodeTransformer):
def visit_Global(self, node: ast.Global) -> ast.Pass:
with fast.location_of(node):
return fast.Pass()

def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Pass:
with fast.location_of(node):
return fast.Pass()

def __call__(self, node: _T) -> _T:
return super().visit(fast.copy_ast(node))

def visit(self, node: _T) -> _T:
return super().visit(fast.copy_ast(node))
def strip_globals_and_nonlocals(
body: List[ast.stmt],
) -> Tuple[List[ast.stmt], List[ast.stmt]]:
new_body: List[ast.stmt] = []
globals_and_nonlocals: List[ast.stmt] = []
for stmt in body:
if isinstance(stmt, (ast.Global, ast.Nonlocal)):
globals_and_nonlocals.append(stmt)
else:
new_body.append(stmt)
return new_body, globals_and_nonlocals


class StatementInserter(ast.NodeTransformer, EmitterMixin):
Expand All @@ -93,18 +91,18 @@ def __init__(
handler_predicate_by_event,
handler_guards_by_event,
)
self._global_nonlocal_stripper: StripGlobalAndNonlocalDeclarations = (
StripGlobalAndNonlocalDeclarations()
)

def _handle_loop_body(
self, node: Union[ast.For, ast.While], orig_body: List[ast.AST]
) -> List[ast.AST]:
loop_node_copy = cast(
Union[ast.For, ast.While], self.orig_to_copy_mapping[id(node)]
Union[ast.For, ast.While],
fast.copy_ast(self.orig_to_copy_mapping[id(node)]),
)
loop_node_copy.body, globals_and_nonlocals = strip_globals_and_nonlocals(
loop_node_copy.body
)
if self.global_guards_enabled:
loop_node_copy = self._global_nonlocal_stripper.visit(loop_node_copy)
loop_guard = make_guard_name(loop_node_copy)
self.register_guard(loop_guard)
else:
Expand Down Expand Up @@ -153,7 +151,7 @@ def _handle_loop_body(
orelse=loop_node_copy.body,
)
]
return ret
return globals_and_nonlocals + ret

def _handle_function_body(
self,
Expand All @@ -162,10 +160,12 @@ def _handle_function_body(
) -> List[ast.AST]:
fundef_copy = cast(
Union[ast.FunctionDef, ast.AsyncFunctionDef],
self.orig_to_copy_mapping[id(node)],
fast.copy_ast(self.orig_to_copy_mapping[id(node)]),
)
fundef_copy.body, globals_and_nonlocals = strip_globals_and_nonlocals(
fundef_copy.body
)
if self.global_guards_enabled:
fundef_copy = self._global_nonlocal_stripper.visit(fundef_copy)
function_guard = make_guard_name(fundef_copy)
self.register_guard(function_guard)
else:
Expand All @@ -177,6 +177,7 @@ def _handle_function_body(
and isinstance(orig_body[0].value, ast.Str)
):
docstring = [orig_body.pop(0)]
fundef_copy.body.pop(0)
if len(orig_body) == 0:
return docstring
with fast.location_of(fundef_copy):
Expand Down Expand Up @@ -224,12 +225,10 @@ def _handle_function_body(
TraceEvent.after_function_execution
](fundef_copy)
else orig_body,
orelse=fundef_copy.body
if len(docstring) == 0
else fundef_copy.body[len(docstring) :], # noqa: E203
orelse=fundef_copy.body,
),
]
return docstring + ret
return docstring + globals_and_nonlocals + ret

def _handle_module_body(
self, node: ast.Module, orig_body: List[ast.stmt]
Expand Down Expand Up @@ -276,6 +275,10 @@ def _handle_stmt(
self, node: ast.AST, field_name: str, inner_node: ast.stmt
) -> List[ast.stmt]:
stmts_to_extend: List[ast.stmt] = []
if isinstance(
node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.For, ast.While)
) and isinstance(inner_node, (ast.Global, ast.Nonlocal)):
return stmts_to_extend
stmt_copy = cast(ast.stmt, self.orig_to_copy_mapping[id(inner_node)])
main_and_maybe_after = self._make_main_and_after_stmt_stmts(
node, field_name, inner_node, stmt_copy
Expand Down

0 comments on commit 78c19ec

Please sign in to comment.