Skip to content

Commit

Permalink
[ Dy2Static ] Add closure analysis for control flow and add some unit…
Browse files Browse the repository at this point in the history
…test (PaddlePaddle#43713)

* add closure analysis for control flow and add some unittest

* finetune the design of FunctionScopeVisitor

* fix

* fix python check

* fix code by code review
  • Loading branch information
2742195759 authored and sneaxiy committed Jun 27, 2022
1 parent 6087536 commit 1f58606
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,99 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
return ret


class NameScope:

def __init__(self):
""" we don't analyze the read only variable
because they keep the same in control flow.
"""
self.globals = set()
self.nonlocals = set()
self.args = set()
self.w_vars = set() # all vars been stored,
# may be globals or non-locals
def created_vars(self):
return self.w_vars - self.globals - self.nonlocals - self.args

def write_vars(self):
return self.w_vars

def global_vars(self):
return self.globals


class FunctionNameLivenessAnalysis(gast.NodeVisitor):
""" analyze the liveness of a function.
every variables stored in this scope will be collected,
in addition with global/nonlocal information.
1. global variable is stored in node.var_globals.
2. nonlocal variable is stored in node.var_nonlocals.
3. arguments is stored in node.var_args.
For example:
def func(*args, **kargs):
a = 12
global i,j
nonlocal x,y
print(a)
i = k
for m in range(10):
q = 12
After this visitor we have:
# node is the FunctionDef node with name: "func"
node.pd_scope = NameScope(
globals = ['i', 'j'],
nonlocals = ['x', 'y'],
args = ['args', 'kargs'],
wr_vars = ['a', 'i', 'q', 'm']
)
"""

def __init__(self, root_node):
self.funcdef_stack = []
self.visit(root_node)

def _current_funcdef_scope(self):
return self.funcdef_stack[-1].pd_scope

def visit_Name(self, node):
self.generic_visit(node)
write_context = (gast.Store, gast.AugStore, gast.Del)
if isinstance(node.ctx, write_context):
self._current_funcdef_scope().w_vars.add(node.id)

def visit_FunctionDef(self, node):
setattr(node, 'pd_scope', NameScope())
self.funcdef_stack.append(node)
self._current_funcdef_scope().args |= set(
self._get_argument_names(node))
self.generic_visit(node)
self.funcdef_stack.pop()

def visit_Global(self, node):
self._current_funcdef_scope().globals |= set(node.names)

def visit_Nonlocal(self, node):
self._current_funcdef_scope().nonlocals |= set(node.names)

def _get_argument_names(self, node):
""" get all arguments name in the functiondef node.
this node is local to the function and shouldn't
be created.
"""
assert isinstance(
node, gast.FunctionDef), "Input node is not function define node"
names = [a for a in node.args.args]
names.append(node.args.vararg)
names.append(node.args.kwarg)
names = [i.id for i in names if i is not None]
return names


class NameVisitor(gast.NodeVisitor):
'''
Analysis name liveness for loop transformer
Expand All @@ -122,7 +215,6 @@ def __init__(self, root_node):

# List of nodes that have scope of variables.
self.nodes_with_scope = []

self.blacklist_names = {"False", "True", "None"}

# Mapping from gast.While/gast.For to variable nodes
Expand Down Expand Up @@ -244,6 +336,7 @@ def visit_Name(self, node):
type(gast.AugStore()),
type(gast.Del())
}

for loop_node in self.current_loop:
self.in_loop_vars[loop_node].append(node)
if type(node.ctx) in write_context:
Expand All @@ -255,6 +348,7 @@ def visit_Name(self, node):
def visit_FunctionDef(self, node):
self.nodes_with_scope.append(node)
self.blacklist_names.add(node.name)

# The variables in the function are not visible to the outside scope.
before_func_seen_vars = copy.copy(self.current_seen_vars)

Expand Down Expand Up @@ -353,6 +447,9 @@ def _is_call_func_name_node(self, node):
return True
return False

def _is_global_or_nonlocal(self, node):
return False

def _is_ancestor_node(self, ancestor_node, node):
parent_node = self._get_parent_node(node)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import FunctionNameLivenessAnalysis
from paddle.utils import gast
import inspect


class JudgeVisitor(gast.NodeVisitor):

def __init__(self, ans):
self.ans = ans

def visit_FunctionDef(self, node):
scope = node.pd_scope
expected = self.ans.get(node.name, set())
assert scope.created_vars() == expected, "Not Equals."
self.generic_visit(node)


def test_normal_0(x):

def func():
if True:
i = 1

func()
return i


def test_normal_argument(x):
x = 1

def func():
if True:
print(x)
i = 1

func()
return x


def test_global(x):
global t
t = 10

def func():
if True:
print(x)
i = 1

func()
return x


def test_nonlocal(x, *args, **kargs):
i = 10

def func(*args, **kargs):
nonlocal i
k = 10
if True:
print(x)
i = 1

func(*args, **kargs)
return x


class TestClosureAnalysis(unittest.TestCase):

def setUp(self):
self.init_dygraph_func()

def init_dygraph_func(self):
self.all_dygraph_funcs = [
test_nonlocal, test_global, test_normal_0, test_normal_argument
]
self.answer = [
{
'func': set('k'),
'test_nonlocal': set('i')
},
{
'func': set({'i'}),
},
{
'func': set('i'),
},
{
'func': set('i'),
},
]

def test_main(self):
for ans, func in zip(self.answer, self.all_dygraph_funcs):
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = FunctionNameLivenessAnalysis(gast_root)
JudgeVisitor(ans).visit(gast_root)


def TestClosureAnalysis_Attribute_func():
# in this function, only self is a Name, self.current is a Attribute. self is read and self.current.function is store()
i = 0
self.current.function = 12


class TestClosureAnalysis_Attribute(TestClosureAnalysis):

def init_dygraph_func(self):

self.all_dygraph_funcs = [TestClosureAnalysis_Attribute_func]
self.answer = [{"TestClosureAnalysis_Attribute_func": set({'i'})}]


if __name__ == '__main__':
unittest.main()

0 comments on commit 1f58606

Please sign in to comment.