Add oneflow.nn.functional.depend api #9807
Conversation
@@ -1005,6 +1005,9 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
  // pinned identity can be pruned since GenerateOptimizerOpConfs pass has
  // already construct a complete computational graph
  JUST(DoPass("PrunePinnedIdentityOpPass"));
  // prune depend OP and add ctrl_in_op to op_conf accordingly
  // to express the same semantics and avoid performance loss
  JUST(DoPass("PruneDependOpPass"));
Why is the pass placed here rather than earlier in the pass list?
The code has been updated: PruneDependOpPass is now moved up to run before PruneAmpWhiteIdentityOpPass.
Reasoning:
(1) Running PruneDependOpPass as early as possible uncovers more operator-optimization opportunities (e.g., once a Depend OP is removed, the conditions for FuseAddToOutputPass may become satisfied).
(2) However, some of the earlier passes do not transfer or preserve control edges when deleting or updating OPs (e.g., EliminateDeadNodesPass, AutoMixedPrecision). If PruneDependOpPass ran before them, the newly added control edges could be lost and become ineffective.
After reading the code and tests of those earlier passes, moving PruneDependOpPass to just before PruneAmpWhiteIdentityOpPass looks like the right spot.
if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }
if (ctx->depend_tensor_requires_grad) {
  in_grads->at(1) =
      JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), out_grads.at(0)->dtype(),
If the backward pass is to be implemented, shouldn't the dtype and device match those of depend_tensor?
The code has been updated: the dtype and device of depend_tensor's gradient now match depend_tensor itself.
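A minimal Python illustration of the agreed behaviour (the real change lives in the C++ backward function; the shapes and dtypes below are arbitrary): the zero gradient produced for the depend tensor takes the depend tensor's own dtype and device rather than those of the upstream gradient.

```python
import oneflow as flow

out_grad = flow.ones(2, 3, dtype=flow.float32)     # upstream gradient
depend_tensor = flow.zeros(4, dtype=flow.float16)  # tensor passed as `depend`

# Build the depend tensor's zero gradient from its own dtype/device.
depend_grad = flow.zeros(
    *depend_tensor.shape, dtype=depend_tensor.dtype, device=depend_tensor.device
)
assert depend_grad.dtype == depend_tensor.dtype    # float16, not out_grad's float32
```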
// GetRelativeNodes() considers the chain of multiple depend OP Nodes and processes them
// from top to bottom, so skip the intermediate nodes
if (!IsDependOPNodeAtTop(op_node, del_nodes)) { continue; }
const std::vector<RelativeNodes> relatives = GetRelativeNodes(op_node, del_nodes);
This piece of logic is a bit cryptic. Could you add some graph-style comments or diagrams to make it more intuitive?
I went through the pass code and don't see any problems so far, but the current algorithm logic is still too cryptic.
The pass code has been refactored along these lines.
// Step 1.3 process src nodes
const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node);
if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) {
If the graph could be modified while it is being traversed, this piece of logic could be dropped.
That approach is difficult here due to missing API support.
For reference, existing simpler passes such as EliminateDeadNodesPass and PruneAmpWhiteIdentityOpPass could in principle modify the graph while traversing it, but they do not; instead they follow the flow of building the OpGraph -> analyzing the OpGraph and recording the changes -> applying the changes to the Job object.
CHECK(src_node);
const OpNode* nearest_depend_node = node_info.second.nearest_depend_node;
const auto& old_lbi = nearest_depend_node->op().BnInOp2Lbi(nearest_depend_node->op().SoleObn());
const auto& new_lbi = src_node->op().BnInOp2Lbi(src_node->op().SoleObn());
Couldn't src_node have more than one output? For example:
a, b = op0()
c = op1()
b = depend(b, c)
d = op2(b)
You have to determine the right output from the connecting edge.
The code has been updated to handle src_node having multiple outputs, and a unit test for this case has been added (test_depend_graph_case7).
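For illustration, a hedged eager-style sketch of the scenario described above (the op bodies here are made up; the real coverage is in test_depend_graph_case7):

```python
import oneflow as flow

def op0():
    # an op with two outputs; only the second one flows through depend()
    return flow.ones(2), flow.full((2,), 2.0)

a, b = op0()
c = flow.full((2,), 3.0)                 # op1
b = flow.nn.functional.depend(b, c)      # depend consumes b, not a
d = b + 1.0                              # op2 consumes the depend output
# Inside nn.Graph, the pruning pass has to pick the correct output of op0
# (b's logical blob) via the connecting edge instead of assuming a sole output.
```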
LGTM, I don't see any major issues now.
One note: the current depend op cannot add control edges to source ops to control their execution order. For example, the three ops below have no inputs, so the user cannot order them through the depend op. That scenario can be left for later.
a = op0()
b = op1()
c = op2()
CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src);
const auto& old_lbi =
    depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn());
const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src);
Couldn't you just use the input of depend_node_nearest_src directly? That said, the current approach is fine as well.
Yes, that would work. However, the logic around lines 148-168 of the pass (Step 1.3) involves updating src_node, and dropping the src_node bookkeeping would make that part feel less natural.
@@ -2764,6 +2764,10 @@
  signature: "Tensor (Tensor input) => IsFinite"
  bind_python: True

- name: "depend"
  signature: "Tensor (Tensor input, Tensor depend_tensor) => Depend"
Should the function signature be polished a bit? For example, could `Tensor depend_tensor` simply be named `depend`? And is establishing control edges with a list of tensors being considered here?
(1) `Tensor depend_tensor` has been renamed to `depend`;
(2) The `depend` argument now accepts either a Tensor or a List[Tensor], and test cases for the List[Tensor] case have been added.
@@ -532,6 +532,17 @@ def set_prune_amp_white_identity_ops(func_desc, value=True):
    func_desc.job_config_proto.prune_amp_white_identity_ops = value


@oneflow_function_config("prune_depend_ops")
def set_prune_depend_ops(func_desc, value=True):
Is this actually used anywhere?
The graph config options now all live at https://oneflow.readthedocs.io/en/master/graph.html#config-options-on-a-graph
This python/oneflow/framework/function_util.py is scheduled for removal.
Right, it is unused. I asked about this code before and was told that "function_util.py contains pre-0.4 interfaces and will be cleaned up."
I can delete this part.
> The graph config options now all live at https://oneflow.readthedocs.io/en/master/graph.html#config-options-on-a-graph
> This python/oneflow/framework/function_util.py is scheduled for removal.

My thinking was:
(a) this OP is expected to be used rarely;
(b) the pass that optimizes this OP is relatively safe;
(c) too many config options would cost users extra time when reading the docs.
Taking these together, I did not add a config switch for this pass.
If necessary, I can add one.
> If necessary, I can add one.

Then let's leave it out for now.
This PR adds a Python OP that is used to:
(1) prevent specified OPs from being eliminated or reordered during static-graph optimization;
(2) serve as a user-facing interface for adding control edges to the static graph, so that the execution order can be constrained or modified.
Similar OPs exist in other frameworks with static-graph semantics, for example:
MindSpore (https://www.mindspore.cn/docs/zh-CN/r1.9/api_python/ops/mindspore.ops.Depend.html)
TensorFlow (https://www.tensorflow.org/api_docs/python/tf/control_dependencies)
Features:
(1) To avoid performance loss in eager mode, the Python interface checks whether it is running in eager mode or graph mode, and in eager mode it returns the input directly;
(2) To avoid performance loss in graph mode, a pass with a configurable switch is added that removes the extra OPs and adds the corresponding low-level control edges;
(3) Deadlocks caused by self-loops are taken into account;
(4) The pass handles chains of multiple depend OPs, as well as the case where control edges might otherwise be added more than once;
(5) The kernel directly reuses existing code;
(6) Unit tests (covering a variety of possible user usages) and documentation are included.
Effect:
Take the first example in the unit tests (test_depend_graph_case0).
Network definition:
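A hypothetical reconstruction of what the network could look like (the original definition is not reproduced in this document; the module structure, scalar value, and shapes are assumptions inferred from the op names in the plan dump discussed below):

```python
import oneflow as flow
from oneflow import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        a = x * 2.0                       # yields an op like "model-scalar_mul-0"
        # In eager mode depend() is a pass-through; inside nn.Graph it keeps the
        # scalar_mul alive and forces it to run before the matmul below.
        x = nn.functional.depend(x, a)
        return self.linear(x)             # yields an op like "model.linear-matmul-1"

class TestGraph(nn.Graph):
    def __init__(self):
        super().__init__()
        self.model = Model()

    def build(self, x):
        return self.model(x)

graph = TestGraph()
print(graph(flow.randn(2, 3)))
```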
Screenshot of job_TestGraph_0_plan.dot when nn.functional.depend is not used (screenshot not reproduced here):
As the plan graph shows, there is no execution-order constraint between the OPs "model-scalar_mul-0" and "model.linear-matmul-1", and judging from the op IDs the latter is expected to run before the former, which does not match the OP order the user defined.
Screenshot of job_TestGraph_0_plan.dot when nn.functional.depend is used (screenshot not reproduced here):
As the plan graph shows, a control edge has been added between "model-scalar_mul-0" and "model.linear-matmul-1", and judging from the op IDs the former is expected to run before the latter, which achieves the user's goal of controlling the OP execution order. In addition, because of the control edge, "model-scalar_mul-0" is protected from being eliminated by other passes.
PS: This is my first operator contribution to OneFlow; I hope it can be accepted.
Please let me know promptly if anything needs to be added.