Skip to content

Commit

Permalink
[BugFix] fixed a bug that deco_name can't be parsed corrected (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#46297) (PaddlePaddle#46366)

* use re replace judge by case

* simplify re
  • Loading branch information
feifei-111 authored Sep 23, 2022
1 parent 980292c commit cbf3f4b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import warnings

import re
from paddle.fluid.dygraph.dygraph_to_static.utils import RE_PYNAME, RE_PYMODULE

IGNORE_NAMES = [
'declarative', 'to_static', 'dygraph_to_static_func', 'wraps',
Expand Down Expand Up @@ -65,25 +66,30 @@ def visit_FunctionDef(self, node):

for deco in reversed(deco_list):
# skip INGNORE_NAMES
if isinstance(deco, gast.Attribute):
deco_name = deco.attr
elif isinstance(deco, gast.Call):
if hasattr(deco.func, 'args'):
deco_name = deco.func.args[0].id
elif hasattr(deco.func, 'attr'):
deco_name = deco.func.attr
else:
deco_name = deco.func.id
deco_full_name = ast_to_source_code(deco).strip()
if isinstance(deco, gast.Call):
# match case like :
# 1: @_jst.Call(a.b.c.d.deco)()
# 2: @q.w.e.r.deco()
re_tmp = re.match(
r'({module})*({name}\(){{0,1}}({module})*({name})(\)){{0,1}}\(.*$'
.format(name=RE_PYNAME, module=RE_PYMODULE), deco_full_name)
deco_name = re_tmp.group(4)
else:
deco_name = deco.id
# match case like:
# @a.d.g.deco
re_tmp = re.match(
r'({module})*({name})$'.format(name=RE_PYNAME,
module=RE_PYMODULE),
deco_full_name)
deco_name = re_tmp.group(2)
if deco_name in IGNORE_NAMES:
continue
elif deco_name == 'contextmanager':
warnings.warn(
"Dy2Static : A context manager decorator is used, this may not work correctly after transform."
)

deco_full_name = ast_to_source_code(deco).strip()
decoed_func = '_decoedby_' + deco_name

# get function after decoration
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def visit(self, node):
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip'

RE_PYNAME = '[a-zA-Z0-9_]+'
RE_PYMODULE = '[a-zA-Z0-9_]+\.'

# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
FullArgSpec = collections.namedtuple('FullArgSpec', [
Expand Down

0 comments on commit cbf3f4b

Please sign in to comment.