Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add oneflow.nn.functional.depend api #9807

Merged
merged 31 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7fcdb35
add a new OP: oneflow.nn.functional.depend
Jan 18, 2023
7245f9f
reformat .py file and remove unused import
Jan 18, 2023
1b5db60
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 24, 2023
1de065a
Fix bugs in prune_depend_op_pass
PYNing Jan 30, 2023
ddddbf6
Add UnitTest and DocTest for Depend OP
PYNing Jan 30, 2023
137cfd8
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 30, 2023
354b007
fix license and python format
PYNing Jan 30, 2023
2379f64
fix C++ format
PYNing Jan 30, 2023
ad334bc
fix error in static analysis
PYNing Jan 30, 2023
80e094e
adjust pass order for potential better optimization
PYNing Jan 31, 2023
fd13735
Fix autograd of depend OP
PYNing Jan 31, 2023
cd592da
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 31, 2023
994226d
Simplify logic in PruneDependOpPass
PYNing Jan 31, 2023
5ff01d2
fix logic in PruneDependOpPass and refine comments
PYNing Jan 31, 2023
0ba2e63
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Jan 31, 2023
25cc35c
fix cpp format error
PYNing Feb 1, 2023
684c882
add log for PruneDependOpPass debug
PYNing Feb 1, 2023
feed392
fix logic of PruneDependOpPass
PYNing Feb 1, 2023
56fb105
Refactor code of PruneDependOpPass for Readability
PYNing Feb 1, 2023
8c34c5f
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 3, 2023
fa2efad
Refactor code of PruneDependOpPass for Readability
PYNing Feb 3, 2023
e72c1da
enhance: consider source node has multiple outputs
PYNing Feb 6, 2023
ba854bd
add two more tests to TestDependGraph
PYNing Feb 6, 2023
5d35a37
add one more test case to TestDependGraph
PYNing Feb 6, 2023
439e229
reformat python file
PYNing Feb 6, 2023
0e42ddf
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 6, 2023
4a7ee4f
rename the second parameter
PYNing Feb 7, 2023
3c4d87c
support multiple tensors from different OPs
PYNing Feb 7, 2023
7829fb2
Merge branch 'Oneflow-Inc:master' into add_depend_op
PYNing Feb 7, 2023
7ce7f52
Merge branch 'master' into add_depend_op
PYNing Feb 7, 2023
4dce37b
Merge branch 'master' into add_depend_op
PYNing Feb 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions oneflow/core/autograd/gradient_funcs/depand.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// Saved state for the backward pass of the "depend" op.
struct DependCaptureState : public AutoGradCaptureState {
  bool in_requires_grad = false;             // whether input "in" needs a gradient
  bool depend_tensor_requires_grad = false;  // whether "depend_tensor" needs a gradient
  Shape depend_tensor_shape;  // shape used to build the zero gradient for "depend_tensor"
};

// Backward rule for the "depend" user op: the op is an identity on its first
// input, so the output gradient flows straight through to "in", while the
// control-only "depend_tensor" input receives a zero gradient.
class Depend : public OpExprGradFunction<DependCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

  // Record which inputs require gradients. inputs = {in, depend_tensor}, outputs = {out}.
  Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
    ctx->in_requires_grad = inputs.at(0)->requires_grad();
    ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad();
    // Only the shape is needed to materialize the zero gradient in Apply().
    if (ctx->depend_tensor_requires_grad) { ctx->depend_tensor_shape = *(inputs.at(1)->shape()); }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
    in_grads->resize(2);
    // Identity pass-through for the data input.
    if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }
    if (ctx->depend_tensor_requires_grad) {
      // NOTE(review): the zero gradient is built with the dtype/device of
      // out_grads[0]; if depend_tensor may differ in dtype or device, its own
      // dtype/device should be captured instead — TODO confirm.
      in_grads->at(1) =
          JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), out_grads.at(0)->dtype(),
                                    JUST(out_grads.at(0)->device())));
    }
    return Maybe<void>::Ok();
  }
};

// Register the backward rule above for the "depend" user op.
REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend);

} // namespace one
} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2749,6 +2749,10 @@
signature: "Tensor (Tensor input) => IsFinite"
bind_python: True

# Functional API entry for oneflow.nn.functional.depend: forwards `input`
# unchanged while establishing a control dependency on `depend_tensor`.
- name: "depend"
  signature: "Tensor (Tensor input, Tensor depend_tensor) => Depend"
  bind_python: True

- name: "roc_auc_score"
signature: "Tensor (Tensor label, Tensor pred) => RocAucScore"
bind_python: True
Expand Down
16 changes: 16 additions & 0 deletions oneflow/core/functional/impl/util_ops_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,29 @@ class IsFiniteFunctor final : public UtilOpsFunctor {
}
};

// Functor backing the "Depend" functional API: dispatches the "depend" user op,
// which forwards `in` and records a control dependency on `depend_tensor`.
class DependFunctor {
 public:
  DependFunctor()
      : op_(CHECK_JUST(
            one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build())) {}

  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
                           const std::shared_ptr<one::Tensor>& depend_tensor) const {
    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, depend_tensor});
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

} // namespace impl

using namespace impl;

// Register all util-op functors under their Python-facing names in a single
// library block: one macro expansion avoids repeating the registration
// boilerplate four times in this translation unit.
ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<IsNanFunctor>("IsNan");
  m.add_functor<IsInfFunctor>("IsInf");
  m.add_functor<IsFiniteFunctor>("IsFinite");
  m.add_functor<DependFunctor>("Depend");
};

} // namespace functional
} // namespace one
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,9 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
// pinned identity can be pruned since GenerateOptimizerOpConfs pass has
// already construct a complete computational graph
JUST(DoPass("PrunePinnedIdentityOpPass"));
// prune depend OP and add ctrl_in_op to op_conf accordingly
// to express the same semantics and avoid performance loss
JUST(DoPass("PruneDependOpPass"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥放到这个位置,而不是更前面

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新代码,将PruneDependOpPass提前到PruneAmpWhiteIdentityOpPass前。
理由:
(1)将PruneDepend尽早执行,可以发掘更多的算子优化空间(如删除Depend OP后可能满足FuseAddToOutputPass的执行条件);
(2)但在前面的部分Pass在删除或更新OP时未考虑控制边的转移或保持(如EliminateDeadNodesPassAutoMixedPrecision)。如果放在它们之前执行,新添加的控制边可能丢失导致失效。

经阅读前面的Pass代码和测试,将PruneDependOpPass的执行提前到PruneAmpWhiteIdentityOpPass之前比较合适。

JUST(DoPass("ReplaceEmbeddingOps"));
JUST(DoPass("SequentialOneEmbeddingOpsPass"));
JUST(DoPass("FuseEmbeddingShuffleInteractionPass"));
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/job/job_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ message JobConfigProto {
optional bool prune_parallel_cast_ops = 509 [default = true];
optional bool prune_cast_to_static_shape_ops = 510 [default = true];
optional bool prune_amp_white_identity_ops = 511 [default = true];
optional bool prune_depend_ops = 512 [default = true];

optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];
optional bool enable_auto_mixed_precision = 602 [default = false];
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/job/job_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class JobDesc final {
bool prune_parallel_cast_ops() const { return job_conf_.prune_parallel_cast_ops(); }
bool prune_cast_to_static_shape_ops() const { return job_conf_.prune_cast_to_static_shape_ops(); }
bool prune_amp_white_identity_ops() const { return job_conf_.prune_amp_white_identity_ops(); }
// Whether PruneDependOpPass is enabled for this job (job_conf.proto field 512, default true).
bool prune_depend_ops() const { return job_conf_.prune_depend_ops(); }
bool enable_auto_parallel() const { return job_conf_.enable_auto_parallel(); }
int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }

Expand Down
249 changes: 249 additions & 0 deletions oneflow/core/job_rewriter/prune_depend_op_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include <string>
#include <vector>
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/register/logical_blob_id.pb.h"

namespace oneflow {

namespace {

// Nodes related to one depend-OP chain, collected while walking the op graph.
struct RelativeNodes {
  const OpNode* input_node = nullptr;        // real data producer feeding the chain (may be null)
  const OpNode* output_node = nullptr;       // non-depend consumer of the chain's output
  const OpNode* nearest_del_node = nullptr;  // depend node whose output lbi the consumer reads
  std::vector<const OpNode*> in_ctrl_nodes = {};  // nodes to re-attach via ctrl_in_op_name
};

// True iff the operator is a user op of type "depend".
bool IsDependyOp(const OperatorConf& op) {
  if (!op.has_user_conf()) { return false; }
  return op.user_conf().op_type_name() == "depend";
}

// The pass has work to do only if the job contains at least one depend op.
bool NeedDoPass(const Job& job) {
  for (const auto& op_conf : job.net().op()) {
    if (IsDependyOp(op_conf)) { return true; }
  }
  return false;
}

// Walk the in-edges of a depend node and return the producer OpNode whose blob
// feeds the input named `target_tensor_name` (e.g. "in_0"); nullptr if no edge
// carries that input name.
const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,
                                          const std::string& target_tensor_name) {
  CHECK(IsDependyOp(op_node->op().op_conf()));
  for (const OpEdge* edge : op_node->in_edges()) {
    const OpNode* src_node = edge->src_node();
    const std::string& src_op_name = src_node->op().op_name();
    // lbi2ibns maps each logical blob carried by this edge to the input blob
    // names it feeds on the destination op.
    for (const auto& lbi_and_ibns : edge->lbi2ibns()) {
      if (lbi_and_ibns.first.op_name() != src_op_name) { continue; }
      for (const std::string& ibn : lbi_and_ibns.second) {
        if (ibn == target_tensor_name) { return src_node; }
      }
    }
  }
  return nullptr;
}

// Producer connected to the depend op's data input ("in_0"); nullptr if absent.
const OpNode* GetNodeFromInputEdge(const OpNode* op_node) {
  return GetNodeFromEdgeByTensorName(op_node, "in_0");
}

// Producer connected to the depend op's control input ("depend_tensor_0"); nullptr if absent.
const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {
  return GetNodeFromEdgeByTensorName(op_node, "depend_tensor_0");
}

bool IsDependOPNodeAtTop(const OpNode* op_node, HashSet<const OpNode*>& del_nodes) {
CHECK(IsDependyOp(op_node->op().op_conf()));
const OpNode* input_op_node = GetNodeFromInputEdge(op_node);
const OpNode* in_ctrl_op_node = GetNodeFromInCtrlEdge(op_node);
if (del_nodes.find(input_op_node) == del_nodes.end()
&& del_nodes.find(in_ctrl_op_node) == del_nodes.end()) {
return true;
} else {
return false;
}
}

// Recursively walk a chain of depend nodes starting at `op_node`, appending to
// `ret` one RelativeNodes record per reachable non-depend consumer.
// `input_node` is the real data producer discovered so far (or nullptr);
// `in_ctrl_nodes` is passed BY VALUE so each recursion branch accumulates its
// own set of control-edge predecessors.
void GetRelativeNodesHelper(const OpNode* op_node, const HashSet<const OpNode*>& del_nodes,
                            const OpNode* input_node, std::vector<const OpNode*> in_ctrl_nodes,
                            std::vector<RelativeNodes>& ret) {
  CHECK(IsDependyOp(op_node->op().op_conf()));
  for (const OpEdge* out_edge : op_node->out_edges()) {
    const OpNode* out_op_node = out_edge->dst_node();
    if (del_nodes.find(out_op_node) == del_nodes.end()) {
      // "out_op_node" is one of the valid (non-depend) output nodes

      // in this case, record the nodes as a result
      const OpNode* in_ctrl_node_to_check = GetNodeFromInCtrlEdge(op_node);
      if (del_nodes.find(in_ctrl_node_to_check) == del_nodes.end()) {
        in_ctrl_nodes.emplace_back(in_ctrl_node_to_check);
      }

      const OpNode* input_node_to_check = GetNodeFromInputEdge(op_node);
      if (del_nodes.find(input_node_to_check) == del_nodes.end()) {
        // a depend OP chain must not have two distinct data-input nodes
        CHECK(input_node == nullptr);
        input_node = input_node_to_check;
      }
      ret.push_back({input_node, out_op_node, op_node, in_ctrl_nodes});
    } else if (op_node == GetNodeFromInCtrlEdge(out_op_node)) {
      // "out_op_node" is ALSO a depend OP node, and "op_node" feeds its in-control edge

      // in this case, both precursors of "op_node" must be interpreted as in-control
      // nodes: the current "input_node" is NOT a data precursor of the target output
      // node, so move "input_node" into "in_ctrl_nodes" and clear it for the
      // subsequent processing
      if (input_node) in_ctrl_nodes.push_back(input_node);
      input_node = nullptr;
      // continue recursion until the target output node is found
      GetRelativeNodesHelper(out_op_node, del_nodes, input_node, in_ctrl_nodes, ret);
    } else {
      // "out_op_node" is ALSO a depend OP node, and "op_node" feeds its data input edge

      // in this case, "input_node" stays the real precursor of the target output node;
      // append the in-ctrl node to "in_ctrl_nodes", and update or keep "input_node"
      // for the subsequent processing
      const OpNode* in_ctrl_node_to_check = GetNodeFromInCtrlEdge(op_node);
      if (del_nodes.find(in_ctrl_node_to_check) == del_nodes.end()) {
        in_ctrl_nodes.emplace_back(in_ctrl_node_to_check);
      }

      const OpNode* input_node_to_check = GetNodeFromInputEdge(op_node);
      if (del_nodes.find(input_node_to_check) == del_nodes.end()) {
        // a depend OP chain must not have two distinct data-input nodes
        CHECK(input_node == nullptr);
        input_node = input_node_to_check;
      }
      // continue recursion until the target output node is found
      GetRelativeNodesHelper(out_op_node, del_nodes, input_node, in_ctrl_nodes, ret);
    }
  }
}

// Entry point for the recursive walk above: collect every RelativeNodes record
// reachable from the top-of-chain depend node `op_node`.
const std::vector<RelativeNodes> GetRelativeNodes(const OpNode* op_node,
                                                  const HashSet<const OpNode*>& del_nodes) {
  std::vector<RelativeNodes> results;
  GetRelativeNodesHelper(op_node, del_nodes, /*input_node=*/nullptr, /*in_ctrl_nodes=*/{},
                         results);
  return results;
}

// Job pass that deletes "depend" user ops and re-expresses their semantics as
// control edges (ctrl_in_op_name) on the downstream consumers, avoiding the
// runtime cost of the identity op while keeping the execution-order constraint.
class PruneDependOpPass final : public JobPass {
 public:
  PruneDependOpPass() = default;
  ~PruneDependOpPass() override = default;

  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
};

// Remove prunable depend ops: rewire each consumer's input blob to the real
// producer and attach the chain's in-control nodes as ctrl_in_op_name entries.
Maybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {
  if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }
  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }
  const OpGraph op_graph(*job);

  // Collect every op name that already appears as someone's ctrl_in_op_name;
  // a depend op referenced this way must not be pruned.
  HashSet<std::string> ctrl_in_op_names;
  op_graph.ForEachNode([&](const OpNode* op_node) {
    for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) {
      ctrl_in_op_names.insert(ctrl_in_op_name);
    }
  });

  // Depend nodes that are safe to delete: user op of type "depend", no ctrl-in
  // edges of its own, and not the target of another op's control edge.
  HashSet<const OpNode*> del_nodes;
  op_graph.ForEachNode([&](const OpNode* op_node) {
    const std::string& op_name = op_node->op().op_name();
    const OperatorConf& op_conf = op_node->op().op_conf();
    // not a depend op
    if (!IsDependyOp(op_conf)) { return; }
    // has its own ctrl-in edge
    if (!op_conf.ctrl_in_op_name().empty()) { return; }
    // is the ctrl-in of another op
    if (ctrl_in_op_names.find(op_name) != ctrl_in_op_names.end()) { return; }

    del_nodes.insert(op_node);
  });

  HashMap<std::string, OperatorConf> to_update_op_confs;
  std::vector<std::string> del_op_names;
  del_op_names.reserve(del_nodes.size());
  for (const OpNode* op_node : del_nodes) {
    del_op_names.emplace_back(op_node->op().op_name());
    // GetRelativeNodes() handles a chain of several depend nodes from top to
    // bottom, so intermediate chain members are skipped here.
    if (!IsDependOPNodeAtTop(op_node, del_nodes)) { continue; }
    const std::vector<RelativeNodes> relatives = GetRelativeNodes(op_node, del_nodes);

    // adjust op_conf of nodes connected to depend OP nodes
    for (const RelativeNodes& item : relatives) {
      const OpNode* input_node = item.input_node;
      const OpNode* output_node = item.output_node;
      const OpNode* nearest_del_node = item.nearest_del_node;
      const std::vector<const OpNode*>& depend_nodes = item.in_ctrl_nodes;
      // in some cases (e.g. the second branch in GetRelativeNodesHelper()), input nodes
      // are interpreted as in-ctrl nodes; their input_node is then NULL and the ibn
      // rewiring below must be skipped
      if (input_node) {
        // Point the consumer's matching input blobs at the real producer's blob.
        const auto& old_lbi = nearest_del_node->op().BnInOp2Lbi(nearest_del_node->op().SoleObn());
        const auto& new_lbi = input_node->op().BnInOp2Lbi(input_node->op().SoleObn());
        const Operator& out_op = output_node->op();
        for (const std::string& ibn : out_op.input_bns()) {
          if (out_op.BnInOp2Lbi(ibn) == old_lbi) {
            auto iter = to_update_op_confs.find(out_op.op_name());
            if (iter == to_update_op_confs.end()) {
              iter = to_update_op_confs.emplace(out_op.op_name(), out_op.op_conf()).first;
            }
            OperatorConf& op_conf = iter->second;
            const auto& old_val =
                ReplaceInputLbnInOpCustomizedConf(&op_conf, ibn, GenLogicalBlobName(new_lbi));
            CHECK_EQ_OR_RETURN(GenLogicalBlobName(old_lbi), old_val);
          }
        }
      }
      // add ctrl_in_op on the consumer
      const Operator& out_op = output_node->op();
      auto out_iter = to_update_op_confs.find(out_op.op_name());
      if (out_iter == to_update_op_confs.end()) {
        out_iter = to_update_op_confs.emplace(out_op.op_name(), out_op.op_conf()).first;
      }
      OperatorConf& out_op_conf = out_iter->second;
      for (const OpNode* node : depend_nodes) {
        CHECK(output_node != node);  // self-loop found
        // BUGFIX: check the consumer's (possibly already updated) ctrl-in list, not the
        // depend node's own list — for any node in del_nodes that list is empty by
        // construction (filtered above), which made the duplicate check vacuous.
        const auto& existed_ctrl_in_op_names = out_op_conf.ctrl_in_op_name();
        const std::string& new_ctrl_in_op_name = node->op().op_name();
        auto existed_it = std::find(existed_ctrl_in_op_names.begin(),
                                    existed_ctrl_in_op_names.end(), new_ctrl_in_op_name);
        // avoid adding the input node itself or a duplicate control edge
        if (node != input_node && existed_it == existed_ctrl_in_op_names.end()) {
          out_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);
        }
      }
    }
  }

  // Apply the accumulated conf updates, then delete the pruned depend ops.
  JobBuilder job_builder(job);
  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }
  job_builder.DelOps(del_op_names);

  return Maybe<void>::Ok();
}

} // namespace

// Register this pass under the name referenced by DoPass("PruneDependOpPass").
REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass);

} // namespace oneflow
Loading