-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add oneflow.nn.functional.depend api #9807
Changes from 25 commits
7fcdb35
7245f9f
1b5db60
1de065a
ddddbf6
137cfd8
354b007
2379f64
ad334bc
80e094e
fd13735
cd592da
994226d
5ff01d2
0ba2e63
25cc35c
684c882
feed392
56fb105
8c34c5f
fa2efad
e72c1da
ba854bd
5d35a37
439e229
0e42ddf
4a7ee4f
3c4d87c
7829fb2
7ce7f52
4dce37b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/op_expr_grad_function.h" | ||
#include "oneflow/core/functional/functional.h" | ||
|
||
namespace oneflow { | ||
namespace one { | ||
|
||
// Saved state for the backward pass of the "depend" op.
//
// The op forwards inputs[0] unchanged and only records an execution-order
// dependency on inputs[1] (the "depend" tensor). Backward therefore needs
// the depend tensor's metadata so it can build an all-zero gradient of the
// right shape/dtype/device without keeping the tensor itself alive.
struct DependCaptureState : public AutoGradCaptureState {
  bool in_requires_grad = false;              // whether inputs[0] requires grad
  bool depend_tensor_requires_grad = false;   // whether inputs[1] requires grad
  // Metadata of inputs[1]; only filled in when depend_tensor_requires_grad.
  Shape depend_tensor_shape;
  Symbol<DType> depend_tensor_dtype;
  Maybe<Symbol<Device>> depend_tensor_device;
};
|
||
// Backward function for the "depend" op.
//
// Forward semantics: out = depend(in, depend_tensor) passes `in` through
// untouched while adding a control edge from `depend_tensor`. Consequently
// d(out)/d(in) is the identity, and d(out)/d(depend_tensor) is zero.
class Depend : public OpExprGradFunction<DependCaptureState> {
 public:
  // The op carries no attributes, so there is nothing to parse here.
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Records which inputs need gradients. When the depend tensor needs one,
  // its shape/dtype/device are snapshotted so Apply() can fabricate a zero
  // gradient without retaining the tensor.
  Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    const auto& pass_through = inputs.at(0);
    const auto& depend_tensor = inputs.at(1);
    ctx->in_requires_grad = pass_through->requires_grad();
    ctx->depend_tensor_requires_grad = depend_tensor->requires_grad();
    if (ctx->depend_tensor_requires_grad) {
      ctx->depend_tensor_shape = *(depend_tensor->shape());
      ctx->depend_tensor_dtype = depend_tensor->dtype();
      ctx->depend_tensor_device = depend_tensor->device();
    }
    return Maybe<void>::Ok();
  }

  // Routes the output gradient straight to the first input; the depend input,
  // if it requires grad, receives an all-zero tensor matching its metadata.
  Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    if (ctx->depend_tensor_requires_grad) {
      in_grads->at(1) =
          JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), ctx->depend_tensor_dtype,
                                    JUST(ctx->depend_tensor_device)));
    }
    if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
};
|
||
REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend); | ||
|
||
} // namespace one | ||
} // namespace oneflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include <glog/logging.h> | ||
#include <string> | ||
#include <vector> | ||
#include "oneflow/core/common/hash_container.h" | ||
#include "oneflow/core/framework/framework.h" | ||
#include "oneflow/core/graph/node.h" | ||
#include "oneflow/core/graph/op_graph.h" | ||
#include "oneflow/core/job_rewriter/job_pass.h" | ||
#include "oneflow/core/register/logical_blob_id.pb.h" | ||
|
||
namespace oneflow { | ||
|
||
namespace { | ||
|
||
// Per-node bookkeeping for the prune pass. One entry exists per op node;
// `updated` marks nodes whose configuration must be rewritten (or, for
// depend nodes themselves, deleted).
struct UpdatedNodeInfo {
  const OpNode* node = nullptr;                    // the node this record describes
  const OpNode* new_src_node = nullptr;            // real (non-depend) producer to reconnect to
  const OpNode* depend_node_nearest_src = nullptr; // depend node adjacent to the real producer
  const OpNode* depend_node_nearest_dst = nullptr; // depend node adjacent to this consumer
  std::vector<const OpNode*> new_in_ctrl_nodes;    // control-edge sources to add (may hold duplicates; deduped later)
  bool updated = false;                            // whether any rewrite applies to this node
};
|
||
bool IsDependyOp(const OperatorConf& op) { | ||
return op.has_user_conf() && (op.user_conf().op_type_name() == "depend"); | ||
} | ||
|
||
bool NeedDoPass(const Job& job) { | ||
return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp); | ||
} | ||
|
||
// Scans the in-edges of a depend node and returns the producer node whose
// blob is bound to the input named `target_tensor_name` (e.g. "in_0").
// Returns nullptr when no in-edge carries that binding.
const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,
                                          const std::string& target_tensor_name) {
  CHECK(IsDependyOp(op_node->op().op_conf()));
  for (const OpEdge* edge : op_node->in_edges()) {
    const OpNode* producer = edge->src_node();
    const std::string& producer_name = producer->op().op_name();
    // Each edge maps a logical blob id to the input blob names it feeds.
    for (const auto& item : edge->lbi2ibns()) {
      if (item.first.op_name() != producer_name) { continue; }
      const auto& ibns = item.second;
      if (std::find(ibns.begin(), ibns.end(), target_tensor_name) != ibns.end()) {
        return producer;
      }
    }
  }
  return nullptr;
}
|
||
// Producer of the pass-through data input of a depend node (bound to "in_0").
const OpNode* GetNodeFromInputEdge(const OpNode* op_node) {
  static const std::string kDependDataInput = "in_0";
  return GetNodeFromEdgeByTensorName(op_node, kDependDataInput);
}
|
||
// Producer of the in-control input of a depend node (bound to "depend_tensor_0").
const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {
  static const std::string kDependCtrlInput = "depend_tensor_0";
  return GetNodeFromEdgeByTensorName(op_node, kDependCtrlInput);
}
|
||
// Returns the logical blob id that `src_node` feeds into
// `depend_node_nearest_src`. A depend node consumes exactly one blob from its
// data-input producer, so the connecting edge must carry a single lbi.
// Aborts (LOG(FATAL)) if the two nodes are not connected — callers guarantee
// the edge exists, so reaching that path indicates a pass-internal bug.
LogicalBlobId GetNewLbi(const OpNode* src_node, const OpNode* depend_node_nearest_src) {
  CHECK(IsDependyOp(depend_node_nearest_src->op().op_conf()));
  for (const OpEdge* out_edge : src_node->out_edges()) {
    if (out_edge->dst_node() != depend_node_nearest_src) { continue; }
    // CHECK_EQ (not CHECK(a == b)) so a violation logs both operand values.
    CHECK_EQ(out_edge->lbis().size(), 1);
    return out_edge->lbis()[0];
  }
  LOG(FATAL) << "no edge connects " << src_node->op().op_name() << " to depend op "
             << depend_node_nearest_src->op().op_name();
  return {};
}
|
||
// Job-rewrite pass that removes "depend" ops from the job graph: each depend
// node's data edge is reconnected directly from producer to consumer, and its
// depend edge becomes a ctrl_in_op_name entry on the consumer, preserving the
// execution ordering the op expressed without leaving any kernel behind.
class PruneDependOpPass final : public JobPass {
 public:
  PruneDependOpPass() = default;
  ~PruneDependOpPass() override = default;

  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
};
|
||
// Rewrites the job in three phases: (0) topo-sort and index every node,
// (1) walk nodes in topological order and, for each depend node, record on its
// consumers the real producer to reconnect to plus the control edges to add,
// (2) materialize those records into OperatorConf updates and a deletion list,
// (3) commit both through JobBuilder. The graph itself is never mutated while
// being traversed; all changes are staged in `node_info_with_update`.
Maybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {
  if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }
  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }
  const OpGraph op_graph(*job);

  HashMap<std::string, UpdatedNodeInfo> node_info_with_update;
  std::vector<const OpNode*> ordered_nodes;

  // Step 0: topological sort, and set up one bookkeeping record per node.
  // Topological order guarantees that when a depend node is processed, any
  // upstream depend node feeding it has already been processed.
  op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) {
    UpdatedNodeInfo node_info;
    node_info.node = node;
    node_info_with_update.emplace(node->op().op_name(), node_info);
    ordered_nodes.emplace_back(node);
  });

  // Step 1: process nodes in topological order, recording modification info on
  // the consumers of every depend node encountered.
  for (const OpNode* cur_node : ordered_nodes) {
    const std::string& cur_op_name = cur_node->op().op_name();
    const OperatorConf& cur_op_conf = cur_node->op().op_conf();
    if (!IsDependyOp(cur_op_conf)) { continue; }

    // Record modification info on each dst_node of this depend node.
    for (const OpEdge* out_edge : cur_node->out_edges()) {
      const OpNode* dst_node = out_edge->dst_node();
      const Operator& dst_op = dst_node->op();

      UpdatedNodeInfo& updated_dst_node_info = node_info_with_update.find(dst_op.op_name())->second;
      UpdatedNodeInfo& updated_cur_node_info = node_info_with_update.find(cur_op_name)->second;
      updated_dst_node_info.updated = true;
      updated_dst_node_info.depend_node_nearest_dst = cur_node;

      // Step 1.1: the depend tensor's producer becomes an in-ctrl node of dst.
      const OpNode* cur_in_ctrl_node = GetNodeFromInCtrlEdge(cur_node);
      updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_in_ctrl_node);

      // Step 1.2: dst also inherits the depend node's own ctrl-in ops, both
      // those declared on its conf and those accumulated from upstream
      // depend nodes during earlier iterations.
      // NOTE(review): operator[] default-inserts an empty record (node ==
      // nullptr) if ori_ctrl_in_op_name is somehow absent from the graph —
      // presumably every ctrl-in op is always in the graph; verify.
      const auto& ori_in_ctrl_op_names = cur_op_conf.ctrl_in_op_name();
      for (const std::string& ori_ctrl_in_op_name : ori_in_ctrl_op_names) {
        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(
            node_info_with_update[ori_ctrl_in_op_name].node);
      }
      if (updated_cur_node_info.updated) {
        std::vector<const OpNode*>& inherit_in_ctrl_nodes = updated_cur_node_info.new_in_ctrl_nodes;
        for (const OpNode* inherit_in_ctrl_node : inherit_in_ctrl_nodes) {
          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(inherit_in_ctrl_node);
        }
      }

      // Step 1.3: process src nodes.
      const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node);
      if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) {
        // "cur_node" and "dst_node" are both depend nodes, connected like:
        //     other_node   cur_node
        //              \   /
        //            dst_node
        // i.e. cur_node feeds dst_node's in-ctrl input. In this case, every
        // src node of "cur_node" must be treated as an in-ctrl node of dst.
        if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(updated_cur_node_info.new_src_node);
        }
        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_src_node);
      } else {
        // cur_node feeds dst_node's data input: propagate the real
        // (non-depend) producer so dst can be rewired straight to it.
        if (!IsDependyOp(cur_src_node->op().op_conf())) {
          updated_dst_node_info.new_src_node = cur_src_node;
          updated_dst_node_info.depend_node_nearest_src = cur_node;
        } else if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
          // cur_node's input is itself a depend node; reuse the producer
          // already resolved for it (valid thanks to topological order).
          updated_dst_node_info.new_src_node = updated_cur_node_info.new_src_node;
          updated_dst_node_info.depend_node_nearest_src =
              updated_cur_node_info.depend_node_nearest_src;
        }
      }
    }
  }

  // Step 2: extract modification info: new connections for surviving nodes
  // and the list of depend nodes to delete.
  std::vector<std::string> del_node_names;
  HashMap<std::string, OperatorConf> to_update_op_confs;
  for (const auto& node_info : node_info_with_update) {
    // Skip nodes never touched by a depend node.
    if (!node_info.second.updated) { continue; }
    const OpNode* cur_node = node_info.second.node;
    const std::string& cur_op_name = cur_node->op().op_name();
    // Depend nodes themselves are simply deleted.
    if (IsDependyOp(cur_node->op().op_conf())) {
      del_node_names.emplace_back(cur_op_name);
      continue;
    }

    const Operator& cur_op = cur_node->op();
    auto iter = to_update_op_confs.find(node_info.first);
    if (iter == to_update_op_confs.end()) {
      iter = to_update_op_confs.emplace(node_info.first, cur_op.op_conf()).first;
    }
    OperatorConf& cur_op_conf = iter->second;

    // Step 2.1: connect the resolved src_node to cur_node (the consumer that
    // used to read from a depend node's output).
    const OpNode* src_node = node_info.second.new_src_node;
    const OpNode* depend_node_nearest_dst = node_info.second.depend_node_nearest_dst;
    const OpNode* depend_node_nearest_src = node_info.second.depend_node_nearest_src;
    CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src);
    // old_lbi: the depend output cur_node currently consumes;
    // new_lbi: the real producer's blob it should consume instead.
    const auto& old_lbi =
        depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn());
    const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src);

    for (const std::string& ibn : cur_node->op().input_bns()) {
      if (cur_op.BnInOp2Lbi(ibn) == old_lbi) {
        const auto& old_val =
            ReplaceInputLbnInOpCustomizedConf(&cur_op_conf, ibn, GenLogicalBlobName(new_lbi));
        CHECK_EQ(GenLogicalBlobName(old_lbi), old_val);
        VLOG(3) << "Update input edge, Src Node: " << src_node->op().op_name()
                << "\t->\tDst Node: " << cur_op_name;
      }
    }

    // Step 2.2: add in-ctrl ops collected in Step 1.
    const auto& existed_ctrl_in_op_names = cur_op_conf.ctrl_in_op_name();
    for (const OpNode* in_ctrl_node : node_info.second.new_in_ctrl_nodes) {
      // Skip depend nodes — they are being deleted.
      if (IsDependyOp(in_ctrl_node->op().op_conf())) { continue; }
      CHECK(cur_node != in_ctrl_node);  // self-loop found
      const std::string& new_ctrl_in_op_name = in_ctrl_node->op().op_name();
      auto existed_it = std::find(existed_ctrl_in_op_names.begin(), existed_ctrl_in_op_names.end(),
                                  new_ctrl_in_op_name);
      // Skip the data-edge src node and already-present ctrl-in names.
      if (in_ctrl_node != src_node && existed_it == existed_ctrl_in_op_names.end()) {
        cur_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);
        VLOG(3) << "Add in-ctrl edge, Src Node: " << new_ctrl_in_op_name
                << "\t->\tDst Node: " << cur_op_name;
      }
    }
  }

  // Step 3: apply the staged modifications to the job.
  JobBuilder job_builder(job);
  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }
  job_builder.DelOps(del_node_names);
  return Maybe<void>::Ok();
};
|
||
} // namespace | ||
|
||
REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass); | ||
|
||
} // namespace oneflow |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数签名要优化下么?比如Tensor depend_tensor是不是直接叫depend就好了,这里会考虑和一个list of tensor建立控制边么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(1)Tensor depend_tensor 已重命名为 depend;
(2)已支持传入depend的类型为Tensor或List[Tensor],并为List[Tensor]的情形追加了测试样例。