-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add oneflow.nn.functional.depend api #9807
Changes from 9 commits
7fcdb35
7245f9f
1b5db60
1de065a
ddddbf6
137cfd8
354b007
2379f64
ad334bc
80e094e
fd13735
cd592da
994226d
5ff01d2
0ba2e63
25cc35c
684c882
feed392
56fb105
8c34c5f
fa2efad
e72c1da
ba854bd
5d35a37
439e229
0e42ddf
4a7ee4f
3c4d87c
7829fb2
7ce7f52
4dce37b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/op_expr_grad_function.h" | ||
#include "oneflow/core/functional/functional.h" | ||
|
||
namespace oneflow { | ||
namespace one { | ||
|
||
struct DependCaptureState : public AutoGradCaptureState { | ||
bool in_requires_grad = false; | ||
bool depend_tensor_requires_grad = false; | ||
Shape depend_tensor_shape; | ||
}; | ||
|
||
class Depend : public OpExprGradFunction<DependCaptureState> { | ||
public: | ||
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); } | ||
|
||
Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs, | ||
const TensorTuple& outputs, const AttrMap& attrs) const override { | ||
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg) | ||
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) | ||
ctx->in_requires_grad = inputs.at(0)->requires_grad(); | ||
ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad(); | ||
if (ctx->depend_tensor_requires_grad) { ctx->depend_tensor_shape = *(inputs.at(1)->shape()); } | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads, | ||
TensorTuple* in_grads) const override { | ||
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) | ||
in_grads->resize(2); | ||
if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); } | ||
if (ctx->depend_tensor_requires_grad) { | ||
in_grads->at(1) = | ||
JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), out_grads.at(0)->dtype(), | ||
JUST(out_grads.at(0)->device()))); | ||
} | ||
return Maybe<void>::Ok(); | ||
} | ||
}; | ||
|
||
REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend); | ||
|
||
} // namespace one | ||
} // namespace oneflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2749,6 +2749,10 @@ | |
signature: "Tensor (Tensor input) => IsFinite" | ||
bind_python: True | ||
|
||
- name: "depend" | ||
signature: "Tensor (Tensor input, Tensor depend_tensor) => Depend" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 函数签名要优化下么?比如Tensor depend_tensor是不是直接叫depend就好了,这里会考虑和一个list of tensor建立控制边么 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (1)Tensor depend_tensor 已重命名为 depend; |
||
bind_python: True | ||
|
||
- name: "roc_auc_score" | ||
signature: "Tensor (Tensor label, Tensor pred) => RocAucScore" | ||
bind_python: True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1005,6 +1005,9 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() { | |
// pinned identity can be pruned since GenerateOptimizerOpConfs pass has | ||
// already construct a complete computational graph | ||
JUST(DoPass("PrunePinnedIdentityOpPass")); | ||
// prune depend OP and add ctrl_in_op to op_conf accordingly | ||
// to express the same semantics and avoid performance loss | ||
JUST(DoPass("PruneDependOpPass")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为啥放到这个位置,而不是更前面 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已更新代码,将PruneDependOpPass提前到PruneAmpWhiteIdentityOpPass前。 经阅读前面的Pass代码和测试,将PruneDependOpPass的执行提前到PruneAmpWhiteIdentityOpPass之前比较合适。 |
||
JUST(DoPass("ReplaceEmbeddingOps")); | ||
JUST(DoPass("SequentialOneEmbeddingOpsPass")); | ||
JUST(DoPass("FuseEmbeddingShuffleInteractionPass")); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include <glog/logging.h> | ||
#include <string> | ||
#include <vector> | ||
#include "oneflow/core/framework/framework.h" | ||
#include "oneflow/core/graph/node.h" | ||
#include "oneflow/core/graph/op_graph.h" | ||
#include "oneflow/core/job_rewriter/job_pass.h" | ||
#include "oneflow/core/register/logical_blob_id.pb.h" | ||
|
||
namespace oneflow { | ||
|
||
namespace { | ||
|
||
// Groups the nodes surrounding a chain of depend ops that is about to be
// pruned, as seen from one surviving output node.
struct RelativeNodes {
  // producer of the real data input to rewire into output_node's inputs;
  // nullptr when the data input was reinterpreted as an in-control node
  const OpNode* input_node = nullptr;
  // the surviving consumer node whose op_conf will be modified
  const OpNode* output_node = nullptr;
  // the depend node directly preceding output_node; its sole output lbi is the
  // one replaced in output_node's input bns
  const OpNode* nearest_del_node = nullptr;
  // surviving producers that become ctrl_in_op entries of output_node
  std::vector<const OpNode*> in_ctrl_nodes = {};
};
|
||
bool IsDependyOp(const OperatorConf& op) { | ||
return op.has_user_conf() && (op.user_conf().op_type_name() == "depend"); | ||
} | ||
|
||
bool NeedDoPass(const Job& job) { | ||
return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp); | ||
} | ||
|
||
// Finds the producer node connected to `op_node` (which must be a depend op)
// through the input whose bn equals `target_tensor_name` (e.g. "in_0" or
// "depend_tensor_0"). Returns nullptr when no in-edge carries that input name.
const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,
                                          const std::string& target_tensor_name) {
  CHECK(IsDependyOp(op_node->op().op_conf()));
  for (const OpEdge* edge : op_node->in_edges()) {
    const OpNode* producer = edge->src_node();
    const std::string& producer_name = producer->op().op_name();
    for (const auto& lbi_and_ibns : edge->lbi2ibns()) {
      // only consider blobs actually produced by this edge's source node
      if (lbi_and_ibns.first.op_name() != producer_name) { continue; }
      const std::vector<std::string>& ibns = lbi_and_ibns.second;
      if (std::find(ibns.begin(), ibns.end(), target_tensor_name) != ibns.end()) {
        return producer;
      }
    }
  }
  return nullptr;
}
|
||
// Returns the producer feeding the "in_0" input (the pass-through data input)
// of a depend op node, or nullptr if not found on any in-edge.
const OpNode* GetNodeFromInputEdge(const OpNode* op_node) {
  return GetNodeFromEdgeByTensorName(op_node, "in_0");
}
|
||
// Returns the producer feeding the "depend_tensor_0" input (the pure
// control-dependency input) of a depend op node, or nullptr if not found.
const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {
  return GetNodeFromEdgeByTensorName(op_node, "depend_tensor_0");
}
|
||
bool IsDependOPNodeAtTop(const OpNode* op_node, HashSet<const OpNode*>& del_nodes) { | ||
CHECK(IsDependyOp(op_node->op().op_conf())); | ||
const OpNode* input_op_node = GetNodeFromInputEdge(op_node); | ||
const OpNode* in_ctrl_op_node = GetNodeFromInCtrlEdge(op_node); | ||
if (del_nodes.find(input_op_node) == del_nodes.end() | ||
&& del_nodes.find(in_ctrl_op_node) == del_nodes.end()) { | ||
return true; | ||
} else { | ||
return false; | ||
} | ||
} | ||
|
||
// Walks downstream from `op_node` (a depend op scheduled for deletion) until
// nodes that survive the pass are reached, recording for each such output node
// the real data-input producer, the depend node directly preceding the output,
// and every in-control node gathered along the chain. `input_node` and
// `in_ctrl_nodes` accumulate state across the recursion (note: in_ctrl_nodes
// is passed by value on purpose, so each branch keeps its own copy); results
// are appended to `ret`.
void GetRelativeNodesHelper(const OpNode* op_node, const HashSet<const OpNode*>& del_nodes,
                            const OpNode* input_node, std::vector<const OpNode*> in_ctrl_nodes,
                            std::vector<RelativeNodes>& ret) {
  CHECK(IsDependyOp(op_node->op().op_conf()));
  for (const OpEdge* out_edge : op_node->out_edges()) {
    const OpNode* out_op_node = out_edge->dst_node();
    if (del_nodes.find(out_op_node) == del_nodes.end()) {
      // "out_op_node" is one of the valid (surviving) output nodes

      // in this case, record the gathered nodes as a result
      const OpNode* in_ctrl_node_to_check = GetNodeFromInCtrlEdge(op_node);
      if (del_nodes.find(in_ctrl_node_to_check) == del_nodes.end()) {
        in_ctrl_nodes.emplace_back(in_ctrl_node_to_check);
      }

      const OpNode* input_node_to_check = GetNodeFromInputEdge(op_node);
      if (del_nodes.find(input_node_to_check) == del_nodes.end()) {
        // a depend OP chain must not have two distinct data-input nodes
        CHECK(input_node == nullptr);
        input_node = input_node_to_check;
      }
      ret.push_back({input_node, out_op_node, op_node, in_ctrl_nodes});
    } else if (op_node == GetNodeFromInCtrlEdge(out_op_node)) {
      // "out_op_node" is ALSO a depend OP node, and "op_node" feeds its
      // in-control edge

      // in this case, both precursors of "op_node" act as in-control nodes:
      // "input_node" is NOT a data precursor of the target output node, so
      // demote it into "in_ctrl_nodes" and reset "input_node" to null for the
      // rest of the walk
      if (input_node) in_ctrl_nodes.push_back(input_node);
      input_node = nullptr;
      // continue recursion until the target output node is found
      GetRelativeNodesHelper(out_op_node, del_nodes, input_node, in_ctrl_nodes, ret);
    } else {
      // "out_op_node" is ALSO a depend OP node, and "op_node" feeds its
      // data-input edge

      // in this case, "input_node" is the real data precursor of the target
      // output node: append the in-control node to "in_ctrl_nodes" and update
      // (or keep) "input_node" for the rest of the walk
      const OpNode* in_ctrl_node_to_check = GetNodeFromInCtrlEdge(op_node);
      if (del_nodes.find(in_ctrl_node_to_check) == del_nodes.end()) {
        in_ctrl_nodes.emplace_back(in_ctrl_node_to_check);
      }

      const OpNode* input_node_to_check = GetNodeFromInputEdge(op_node);
      if (del_nodes.find(input_node_to_check) == del_nodes.end()) {
        // a depend OP chain must not have two distinct data-input nodes
        CHECK(input_node == nullptr);
        input_node = input_node_to_check;
      }
      // continue recursion until the target output node is found
      GetRelativeNodesHelper(out_op_node, del_nodes, input_node, in_ctrl_nodes, ret);
    }
  }
}
|
||
// Entry point of the recursive walk: collects every
// (input, output, nearest-deleted, in-ctrl) group reachable from `op_node`.
const std::vector<RelativeNodes> GetRelativeNodes(const OpNode* op_node,
                                                  const HashSet<const OpNode*>& del_nodes) {
  std::vector<RelativeNodes> result;
  GetRelativeNodesHelper(op_node, del_nodes, /*input_node=*/nullptr, /*in_ctrl_nodes=*/{}, result);
  return result;
}
|
||
// Job pass that removes "depend" user ops from the graph and rewires each
// consumer with equivalent ctrl_in_op entries, preserving the ordering
// semantics while avoiding the runtime cost of extra ops.
class PruneDependOpPass final : public JobPass {
 public:
  PruneDependOpPass() = default;
  ~PruneDependOpPass() override = default;

  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
};
|
||
// Removes every prunable depend op: rewires each surviving consumer's input to
// the depend chain's real data producer and turns the chain's in-control
// producers into ctrl_in_op entries on the consumer.
Maybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {
  if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }
  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }
  const OpGraph op_graph(*job);

  // names of all ops referenced as a control input by some other op
  HashSet<std::string> ctrl_in_op_names;
  op_graph.ForEachNode([&](const OpNode* op_node) {
    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {
      ctrl_in_op_names.insert(ctrl_in_op_name);
    }
  });

  // a depend op is deletable only when it takes part in no explicit control
  // relation of its own
  HashSet<const OpNode*> del_nodes;
  op_graph.ForEachNode([&](const OpNode* op_node) {
    const std::string& op_name = op_node->op().op_name();
    const OperatorConf& op_conf = op_node->op().op_conf();
    // not a depend op
    if (!IsDependyOp(op_conf)) { return; }
    // has its own ctrl-in ops
    if (!op_conf.ctrl_in_op_name().empty()) { return; }
    // is a ctrl-in op of another op
    if (ctrl_in_op_names.find(op_name) != ctrl_in_op_names.end()) { return; }

    del_nodes.insert(op_node);
  });

  HashMap<std::string, OperatorConf> to_update_op_confs;
  std::vector<std::string> del_op_names;
  del_op_names.reserve(del_nodes.size());
  for (const OpNode* op_node : del_nodes) {
    del_op_names.emplace_back(op_node->op().op_name());
    // GetRelativeNodes() handles a whole chain of depend OP nodes from top to
    // bottom, so intermediate chain members are skipped here
    if (!IsDependOPNodeAtTop(op_node, del_nodes)) { continue; }
    const std::vector<RelativeNodes> relatives = GetRelativeNodes(op_node, del_nodes);

    // adjust op_conf of nodes connected to depend OP nodes
    for (const RelativeNodes& item : relatives) {
      const OpNode* input_node = item.input_node;
      const OpNode* output_node = item.output_node;
      const OpNode* nearest_del_node = item.nearest_del_node;
      const std::vector<const OpNode*>& depend_nodes = item.in_ctrl_nodes;
      // in some cases (e.g. the second branch in GetRelativeNodesHelper()) the
      // input node is reinterpreted as an in-ctrl node; input_node is then NULL
      // and the ibn rewiring must be skipped
      if (input_node) {
        const auto& old_lbi = nearest_del_node->op().BnInOp2Lbi(nearest_del_node->op().SoleObn());
        const auto& new_lbi = input_node->op().BnInOp2Lbi(input_node->op().SoleObn());
        const Operator& out_op = output_node->op();
        for (const std::string& ibn : out_op.input_bns()) {
          if (out_op.BnInOp2Lbi(ibn) == old_lbi) {
            auto iter = to_update_op_confs.find(out_op.op_name());
            if (iter == to_update_op_confs.end()) {
              iter = to_update_op_confs.emplace(out_op.op_name(), out_op.op_conf()).first;
            }
            OperatorConf& op_conf = iter->second;
            const auto& old_val =
                ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, GenLogicalBlobName(new_lbi));
            CHECK_EQ_OR_RETURN(GenLogicalBlobName(old_lbi), old_val);
          }
        }
      }
      // add ctrl_in_op entries on the output node
      const Operator& out_op = output_node->op();
      auto out_iter = to_update_op_confs.find(out_op.op_name());
      if (out_iter == to_update_op_confs.end()) {
        out_iter = to_update_op_confs.emplace(out_op.op_name(), out_op.op_conf()).first;
      }
      OperatorConf& out_op_conf = out_iter->second;
      for (const OpNode* node : depend_nodes) {
        CHECK(output_node != node);  // self-loop found
        // BUGFIX: duplicates must be checked against the (possibly already
        // updated) conf of the OUTPUT op. The previous code inspected
        // op_node->op().op_conf().ctrl_in_op_name() — the depend node being
        // deleted — which is empty by construction (see the del_nodes filter
        // above), so the dedup check could never fire and the same ctrl edge
        // could be added multiple times.
        const auto& existed_ctrl_in_op_names = out_op_conf.ctrl_in_op_name();
        const std::string& new_ctrl_in_op_name = node->op().op_name();
        auto existed_it = std::find(existed_ctrl_in_op_names.begin(),
                                    existed_ctrl_in_op_names.end(), new_ctrl_in_op_name);
        // avoid adding the input node itself or a duplicate control edge
        if (node != input_node && existed_it == existed_ctrl_in_op_names.end()) {
          out_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);
        }
      }
    }
  }

  JobBuilder job_builder(job);
  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }
  job_builder.DelOps(del_op_names);

  return Maybe<void>::Ok();
}
|
||
} // namespace | ||
|
||
REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass); | ||
|
||
} // namespace oneflow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果要实现反向的话,dtype和device应该和depend_tensor一样吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已更新代码,depend_tensor梯度的dtype和device与depend_tensor一致。