Add oneflow.nn.functional.depend api #9807

Merged · 31 commits · Feb 8, 2023
Changes from 25 commits
7fcdb35
add a new OP: oneflow.nn.functional.depend
Jan 18, 2023
7245f9f
reformat .py file and remove unused import
Jan 18, 2023
1b5db60
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 24, 2023
1de065a
Fix bugs in prune_depend_op_pass
PYNing Jan 30, 2023
ddddbf6
Add UnitTest and DocTest for Depend OP
PYNing Jan 30, 2023
137cfd8
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 30, 2023
354b007
fix license and python format
PYNing Jan 30, 2023
2379f64
fix C++ format
PYNing Jan 30, 2023
ad334bc
fix error in static analysis
PYNing Jan 30, 2023
80e094e
adjust pass order for potential better optimization
PYNing Jan 31, 2023
fd13735
Fix autograd of depend OP
PYNing Jan 31, 2023
cd592da
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 31, 2023
994226d
Simplify logic in PruneDependOpPass
PYNing Jan 31, 2023
5ff01d2
fix logic in PruneDependOpPass and refine comments
PYNing Jan 31, 2023
0ba2e63
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 31, 2023
25cc35c
fix cpp format error
PYNing Feb 1, 2023
684c882
add log for PruneDependOpPass debug
PYNing Feb 1, 2023
feed392
fix logic of PruneDependOpPass
PYNing Feb 1, 2023
56fb105
Refactor code of PruneDependOpPass for Readability
PYNing Feb 1, 2023
8c34c5f
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 3, 2023
fa2efad
Refactor code of PruneDependOpPass for Readability
PYNing Feb 3, 2023
e72c1da
enhance: consider source node has multiple outputs
PYNing Feb 6, 2023
ba854bd
add two more tests to TestDependGraph
PYNing Feb 6, 2023
5d35a37
add one more test case to TestDependGraph
PYNing Feb 6, 2023
439e229
reformat python file
PYNing Feb 6, 2023
0e42ddf
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 6, 2023
4a7ee4f
rename the second parameter
PYNing Feb 7, 2023
3c4d87c
support multiple tensors from different OPs
PYNing Feb 7, 2023
7829fb2
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 7, 2023
7ce7f52
Merge branch 'master' into add_depend_op
PYNing Feb 7, 2023
4dce37b
Merge branch 'master' into add_depend_op
PYNing Feb 7, 2023
65 changes: 65 additions & 0 deletions oneflow/core/autograd/gradient_funcs/depand.cpp
@@ -0,0 +1,65 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct DependCaptureState : public AutoGradCaptureState {
bool in_requires_grad = false;
bool depend_tensor_requires_grad = false;
Shape depend_tensor_shape;
Symbol<DType> depend_tensor_dtype;
Maybe<Symbol<Device>> depend_tensor_device;
};

class Depend : public OpExprGradFunction<DependCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->in_requires_grad = inputs.at(0)->requires_grad();
ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad();
if (ctx->depend_tensor_requires_grad) {
ctx->depend_tensor_shape = *(inputs.at(1)->shape());
ctx->depend_tensor_dtype = inputs.at(1)->dtype();
ctx->depend_tensor_device = inputs.at(1)->device();
}
return Maybe<void>::Ok();
}

Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(2);
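// depend acts as identity w.r.t. the data input, so its gradient passes through
// unchanged; the depend input never affects the output value, so it only ever
// receives an all-zero gradient.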
if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }
if (ctx->depend_tensor_requires_grad) {
in_grads->at(1) =
JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), ctx->depend_tensor_dtype,
JUST(ctx->depend_tensor_device)));
}
return Maybe<void>::Ok();
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend);

} // namespace one
} // namespace oneflow
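The gradient rule above makes depend transparent to autograd: the output gradient flows through to the data input unchanged, while the depend tensor receives zeros. A minimal sketch of that behavior, assuming eager execution (variable names are illustrative, not from the PR):

import oneflow as flow

x = flow.ones(2, 3).requires_grad_()
d = flow.ones(2, 3).requires_grad_()
y = flow.nn.functional.depend(x, d)  # y carries the value of x; d only adds a dependency
y.sum().backward()
print(x.grad)  # all ones: identity gradient w.r.t. x
print(d.grad)  # all zeros: d never affects the output value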
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2764,6 +2764,10 @@
signature: "Tensor (Tensor input) => IsFinite"
bind_python: True

- name: "depend"
signature: "Tensor (Tensor input, Tensor depend_tensor) => Depend"
Contributor:
Should the function signature be polished? For example, wouldn't Tensor depend_tensor be better named simply depend? And will building a control edge against a list of tensors be considered here?

Contributor Author (@PYNing, Feb 7, 2023):
(1) Tensor depend_tensor has been renamed to depend.
(2) depend may now be passed as either a Tensor or a List[Tensor], and a test case has been added for the List[Tensor] case.
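A minimal call sketch of the two accepted forms described above (tensor names are hypothetical):

import oneflow as flow

x = flow.randn(2, 3)
a, b = flow.randn(2, 3), flow.randn(2, 3)
y1 = flow.nn.functional.depend(x, a)       # depend on a single tensor
y2 = flow.nn.functional.depend(x, [a, b])  # depend on a list of tensors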

bind_python: True

- name: "roc_auc_score"
signature: "Tensor (Tensor label, Tensor pred) => RocAucScore"
bind_python: True
16 changes: 16 additions & 0 deletions oneflow/core/functional/impl/util_ops_functor.cpp
@@ -60,13 +60,29 @@ class IsFiniteFunctor final : public UtilOpsFunctor {
}
};

class DependFunctor {
public:
DependFunctor() {
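// Input("in") / Input("depend_tensor") yield the input blob names "in_0" and
// "depend_tensor_0", which PruneDependOpPass later matches by name.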
op_ = CHECK_JUST(
one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
const std::shared_ptr<one::Tensor>& depend_tensor) const {
return OpInterpUtil::Dispatch<Tensor>(*op_, {in, depend_tensor});
}

private:
std::shared_ptr<OpExpr> op_;
};

} // namespace impl

using namespace impl;

ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsNanFunctor>("IsNan"); };
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsInfFunctor>("IsInf"); };
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsFiniteFunctor>("IsFinite"); };
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<DependFunctor>("Depend"); };

} // namespace functional
} // namespace one
3 changes: 3 additions & 0 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -986,6 +986,9 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
#ifdef WITH_CUDA
JUST(DoPass("AutoMixedPrecision"));
#endif
// prune depend OP and add ctrl_in_op to op_conf accordingly
// to express the same semantics and avoid performance loss
JUST(DoPass("PruneDependOpPass"));
JUST(DoPass("PruneAmpWhiteIdentityOpPass"));
JUST(DoPass("OptimizerPlacementOptimizationPass"));
// run FuseAddToOutputPass before IRRoundTripBeforeAD since add_2 maybe
1 change: 1 addition & 0 deletions oneflow/core/job/job_conf.proto
@@ -266,6 +266,7 @@ message JobConfigProto {
optional bool prune_parallel_cast_ops = 509 [default = true];
optional bool prune_cast_to_static_shape_ops = 510 [default = true];
optional bool prune_amp_white_identity_ops = 511 [default = true];
optional bool prune_depend_ops = 512 [default = true];

optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];
optional bool enable_auto_mixed_precision = 602 [default = false];
1 change: 1 addition & 0 deletions oneflow/core/job/job_desc.h
@@ -62,6 +62,7 @@ class JobDesc final {
bool prune_parallel_cast_ops() const { return job_conf_.prune_parallel_cast_ops(); }
bool prune_cast_to_static_shape_ops() const { return job_conf_.prune_cast_to_static_shape_ops(); }
bool prune_amp_white_identity_ops() const { return job_conf_.prune_amp_white_identity_ops(); }
bool prune_depend_ops() const { return job_conf_.prune_depend_ops(); }
bool enable_auto_parallel() const { return job_conf_.enable_auto_parallel(); }
int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }

241 changes: 241 additions & 0 deletions oneflow/core/job_rewriter/prune_depend_op_pass.cpp
@@ -0,0 +1,241 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include <string>
#include <vector>
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/logical_blob_id.pb.h"

namespace oneflow {

namespace {

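// Per-node bookkeeping recorded while depend ops are folded away: the surviving
// source node to rewire to, the depend nodes nearest the source/destination ends
// of a depend chain, and the in-ctrl nodes that must be attached to the dst op.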
struct UpdatedNodeInfo {
const OpNode* node = nullptr;
const OpNode* new_src_node = nullptr;
const OpNode* depend_node_nearest_src = nullptr;
const OpNode* depend_node_nearest_dst = nullptr;
std::vector<const OpNode*> new_in_ctrl_nodes;
bool updated = false;
};

bool IsDependyOp(const OperatorConf& op) {
return op.has_user_conf() && (op.user_conf().op_type_name() == "depend");
}

bool NeedDoPass(const Job& job) {
return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp);
}

const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,
const std::string& target_tensor_name) {
CHECK(IsDependyOp(op_node->op().op_conf()));
for (const OpEdge* in_edge : op_node->in_edges()) {
const OpNode* in_op_node = in_edge->src_node();
const std::string& in_op_node_name = in_op_node->op().op_name();
const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns = in_edge->lbi2ibns();

for (const auto& item : lbi2ibns) {
const std::string& lbi_op_name = item.first.op_name();
for (const std::string& tensor_name : item.second) {
if (in_op_node_name == lbi_op_name && tensor_name == target_tensor_name) {
return in_op_node;
}
}
}
}
return nullptr;
}

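// A depend node has exactly two input edges: "in_0" carries the pass-through data,
// while "depend_tensor_0" expresses the control dependency.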
const OpNode* GetNodeFromInputEdge(const OpNode* op_node) {
return GetNodeFromEdgeByTensorName(op_node, "in_0");
}

const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {
return GetNodeFromEdgeByTensorName(op_node, "depend_tensor_0");
}

LogicalBlobId GetNewLbi(const OpNode* src_node, const OpNode* depend_node_nearest_src) {
CHECK(IsDependyOp(depend_node_nearest_src->op().op_conf()));
for (const OpEdge* out_edge : src_node->out_edges()) {
const OpNode* dst_node = out_edge->dst_node();
if (dst_node != depend_node_nearest_src) { continue; }

CHECK(out_edge->lbis().size() == 1);
return out_edge->lbis()[0];
}
// should not reach here
CHECK(false);
return {};
}

class PruneDependOpPass final : public JobPass {
public:
PruneDependOpPass() = default;
~PruneDependOpPass() override = default;

Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
};

Maybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {
if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }
if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }
const OpGraph op_graph(*job);

HashMap<std::string, UpdatedNodeInfo> node_info_with_update;
std::vector<const OpNode*> ordered_nodes;

// Step 0: topological sort, setup a map for recording modification
op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) {
UpdatedNodeInfo node_info;
node_info.node = node;
node_info_with_update.emplace(node->op().op_name(), node_info);
ordered_nodes.emplace_back(node);
});

// Step 1: process node by topological order
// record modification info when meet Depend OP nodes
for (const OpNode* cur_node : ordered_nodes) {
const std::string& cur_op_name = cur_node->op().op_name();
const OperatorConf& cur_op_conf = cur_node->op().op_conf();
if (!IsDependyOp(cur_op_conf)) { continue; }

// record modification info to each dst_node
for (const OpEdge* out_edge : cur_node->out_edges()) {
const OpNode* dst_node = out_edge->dst_node();
const Operator& dst_op = dst_node->op();

UpdatedNodeInfo& updated_dst_node_info = node_info_with_update.find(dst_op.op_name())->second;
UpdatedNodeInfo& updated_cur_node_info = node_info_with_update.find(cur_op_name)->second;
updated_dst_node_info.updated = true;
updated_dst_node_info.depend_node_nearest_dst = cur_node;

// Step 1.1: record a new in-ctrl node
const OpNode* cur_in_ctrl_node = GetNodeFromInCtrlEdge(cur_node);
updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_in_ctrl_node);

// Step 1.2: inherit in-ctrl nodes from Depend OP nodes
const auto& ori_in_ctrl_op_names = cur_op_conf.ctrl_in_op_name();
for (const std::string& ori_ctrl_in_op_name : ori_in_ctrl_op_names) {
updated_dst_node_info.new_in_ctrl_nodes.emplace_back(
node_info_with_update[ori_ctrl_in_op_name].node);
}
if (updated_cur_node_info.updated) {
std::vector<const OpNode*>& inherit_in_ctrl_nodes = updated_cur_node_info.new_in_ctrl_nodes;
for (const OpNode* inherit_in_ctrl_node : inherit_in_ctrl_nodes) {
updated_dst_node_info.new_in_ctrl_nodes.emplace_back(inherit_in_ctrl_node);
}
}

// Step 1.3 process src nodes
const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node);
if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) {
Contributor:
If the graph could be modified while it is being traversed, this piece of logic could be dropped.

Contributor Author (@PYNing, Feb 6, 2023):
That approach is difficult here due to missing API support. Existing, relatively simple passes such as EliminateDeadNodesPass and PruneAmpWhiteIdentityOpPass could in theory modify the graph during traversal, but they do not; instead they follow the flow of building the OpGraph -> analyzing the OpGraph and recording changes -> modifying the Job object according to those changes.

// "cur_node" and "dst_node" are all Depend OP nodes, and their connection is like this
// other_node cur_node
// \ /
// dst_node
// in this case, all src nodes of "cur_node" should be seen as in-ctrl nodes
if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
updated_dst_node_info.new_in_ctrl_nodes.emplace_back(updated_cur_node_info.new_src_node);
}
updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_src_node);
} else {
if (!IsDependyOp(cur_src_node->op().op_conf())) {
updated_dst_node_info.new_src_node = cur_src_node;
updated_dst_node_info.depend_node_nearest_src = cur_node;
} else if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
updated_dst_node_info.new_src_node = updated_cur_node_info.new_src_node;
updated_dst_node_info.depend_node_nearest_src =
updated_cur_node_info.depend_node_nearest_src;
}
}
}
}

// Step 2: extract modification info
// including new connection and to delete nodes
std::vector<std::string> del_node_names;
HashMap<std::string, OperatorConf> to_update_op_confs;
for (const auto& node_info : node_info_with_update) {
// filter nodes not updated
if (!node_info.second.updated) { continue; }
const OpNode* cur_node = node_info.second.node;
const std::string& cur_op_name = cur_node->op().op_name();
// filter Depend nodes
if (IsDependyOp(cur_node->op().op_conf())) {
del_node_names.emplace_back(cur_op_name);
continue;
}

const Operator& cur_op = cur_node->op();
auto iter = to_update_op_confs.find(node_info.first);
if (iter == to_update_op_confs.end()) {
iter = to_update_op_confs.emplace(node_info.first, cur_op.op_conf()).first;
}
OperatorConf& cur_op_conf = iter->second;

// Step 2.1: connect updated src_node with cur_node (dst_node of Depend OP)
const OpNode* src_node = node_info.second.new_src_node;
const OpNode* depend_node_nearest_dst = node_info.second.depend_node_nearest_dst;
const OpNode* depend_node_nearest_src = node_info.second.depend_node_nearest_src;
CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src);
const auto& old_lbi =
depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn());
const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src);
Contributor:
Couldn't the input of depend_node_nearest_src be used directly here? That said, the current approach is also fine.

Contributor Author:
It could. However, the logic at lines 148~168 (Step 1.3) involves updating src_node, and dropping the record of src_node would make that logic feel less natural...

for (const std::string& ibn : cur_node->op().input_bns()) {
if (cur_op.BnInOp2Lbi(ibn) == old_lbi) {
const auto& old_val =
ReplaceInputLbnInOpCustomizedConf(&cur_op_conf, ibn, GenLogicalBlobName(new_lbi));
CHECK_EQ(GenLogicalBlobName(old_lbi), old_val);
VLOG(3) << "Update input edge, Src Node: " << src_node->op().op_name()
<< "\t->\tDst Node: " << cur_op_name;
}
}

// Step 2.2: add in-ctrl OPs
const auto& existed_ctrl_in_op_names = cur_op_conf.ctrl_in_op_name();
for (const OpNode* in_ctrl_node : node_info.second.new_in_ctrl_nodes) {
// filter Depend nodes
if (IsDependyOp(in_ctrl_node->op().op_conf())) { continue; }
CHECK(cur_node != in_ctrl_node); // self-loop found
const std::string& new_ctrl_in_op_name = in_ctrl_node->op().op_name();
auto existed_it = std::find(existed_ctrl_in_op_names.begin(), existed_ctrl_in_op_names.end(),
new_ctrl_in_op_name);
// filter src node or duplicate in-ctrl nodes
if (in_ctrl_node != src_node && existed_it == existed_ctrl_in_op_names.end()) {
cur_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);
VLOG(3) << "Add in-ctrl edge, Src Node: " << new_ctrl_in_op_name
<< "\t->\tDst Node: " << cur_op_name;
}
}
}

// Step 3: apply modification to job
JobBuilder job_builder(job);
for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }
job_builder.DelOps(del_node_names);
return Maybe<void>::Ok();
};

} // namespace

REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass);

} // namespace oneflow
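Taken together, the pass makes depend cost-free at runtime: in a compiled job the depend op is deleted, and the consumer op instead gains a ctrl_in_op_name entry pointing at the producer of the depend tensor. A minimal nn.Graph sketch of the intended use, with illustrative names not taken from the PR's tests:

import oneflow as flow

class DependGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x, y):
        # Require the producer of y to finish before the add consuming x runs.
        x = flow.nn.functional.depend(x, y)
        return x + 1

graph = DependGraph()
out = graph(flow.randn(2, 3), flow.randn(2, 3))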