-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If #43967
[Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If #43967
Conversation
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn, | ||
return_name_ids) | ||
except Exception as e: | ||
if re.search("Unsupported return type of true_fn and false_fn in cond", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里最好加个单测确保这行代码按照预期触发,因为报错信息有可能被别人迭代优化后,这个分支就失效了
while True: | ||
pred = cond() | ||
if isinstance(pred, Variable): | ||
raise Dygraph2StaticException( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise这个error时会触发用户源码行的标记么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if re.search("Unsupported return type of true_fn and false_fn in cond", | ||
str(e)): | ||
raise Dygraph2StaticException( | ||
"Your if/else have different return type. TODO: add link to modifty." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Your if/else have different return type. TODO: add link to modifty." | |
# TODO: add link to modifty. | |
"Your if/else has different return type. %s" % e |
这里如果还没有加TODO的话,最好先把之前的err msg re-throw下
"Incompatible return values of true_fn and false_fn in cond", | ||
str(e)): | ||
raise Dygraph2StaticException( | ||
"Your if/else have different number of return value. TODO: add link to modifty." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
@@ -52,6 +52,8 @@ def __init__(self, wrapper_root): | |||
), "Type of input node should be AstNodeWrapper, but received %s ." % type( | |||
wrapper_root) | |||
self.root = wrapper_root.node | |||
FunctionNameLivenessAnalysis( | |||
self.root) # name analysis of current ast tree. | |||
self.static_analysis_visitor = StaticAnalysisVisitor(self.root) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.static_analysis_visitor 没有用到了,这个我们是不是可以删除掉了?
def create_undefined_var_like(variable): | ||
""" create a undefined var with the same shape and dtype like varaible. | ||
""" | ||
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续可以把RETURN_NO_VALUE_MAGIC_NUM 也放到这个文件,这里import虽然时动态import,单最好让utils成为一个叶子结点的文件,可以被其他文件import。可以后续PR优化
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM | ||
var = data_layer_not_check(unique_name.generate("undefined_var"), | ||
variable.shape, variable.dtype) | ||
assign(RETURN_NO_VALUE_MAGIC_NUM, var) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么还需要一个assign?不可以通过name来过滤么
template = """ | ||
def {func_name}(): | ||
nonlocal {nonlocal_vars} | ||
{nonlocal_vars} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果 nonlocal_vars 是空,还需要return names么?
@@ -1159,7 +1181,10 @@ def assign_skip_lod_tensor_array(input, output): | |||
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block. | |||
""" | |||
if not isinstance(input, Variable) and not isinstance(input, core.VarBase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not isinstance(input, Variable) and not isinstance(input, core.VarBase): | |
if not isinstance(input, (Variable ,core.VarBase)): |
@@ -2377,7 +2402,7 @@ def copy_var_to_parent_block(var, layer_helper): | |||
return parent_block_var | |||
|
|||
|
|||
def cond(pred, true_fn=None, false_fn=None, name=None): | |||
def cond(pred, true_fn=None, false_fn=None, return_name_ids=None, name=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def cond(pred, true_fn=None, false_fn=None, return_name_ids=None, name=None): | |
def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): |
为了兼容性,新增参数一般放到最后
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -2423,6 +2454,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): | |||
true. The default value is ``None`` . | |||
false_fn(callable, optional): A callable to be performed if ``pred`` is | |||
false. The default value is ``None`` . | |||
return_names: A list of strings to represents the name of returned vars. useful to debug. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return_names: -> return_names(支持的数据类型, optional): XXX. Default is None, means XXX.
useful to debug,太 Chinese English 了,要改一下。
顺序在 name 后面。
PR types
Others
PR changes
Others
Describe
FunctionScopeVisitor Enhance and substitute the original NameVisitor in If.
这个PR主要目标是增强了FunctionScopeVisitor的功能,支持控制流的 name 解析。并且替换了IF的NameVisitor和名字解析逻辑。其余的修改是针对上述修改的单测修复。
新的FunctionScopeVisitor主要有如下几个优点:
TODO: