diff --git a/oneflow/core/autograd/gradient_funcs/depand.cpp b/oneflow/core/autograd/gradient_funcs/depand.cpp
new file mode 100644
index 00000000000..784c500c9f7
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/depand.cpp
@@ -0,0 +1,65 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct DependCaptureState : public AutoGradCaptureState {
+  bool in_requires_grad = false;
+  bool depend_tensor_requires_grad = false;
+  Shape depend_tensor_shape;
+  Symbol<DType> depend_tensor_dtype;
+  Maybe<Symbol<Device>> depend_tensor_device;
+};
+
+class Depend : public OpExprGradFunction<DependCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
+
+  Maybe<void> Capture(DependCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 2);   // NOLINT(maybe-need-error-msg)
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);  // NOLINT(maybe-need-error-msg)
+    ctx->in_requires_grad = inputs.at(0)->requires_grad();
+    ctx->depend_tensor_requires_grad = inputs.at(1)->requires_grad();
+    if (ctx->depend_tensor_requires_grad) {
+      ctx->depend_tensor_shape = *(inputs.at(1)->shape());
+      ctx->depend_tensor_dtype = inputs.at(1)->dtype();
+      ctx->depend_tensor_device = inputs.at(1)->device();
+    }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const DependCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
+    in_grads->resize(2);
+    if (ctx->in_requires_grad) { in_grads->at(0) = out_grads.at(0); }
+    if (ctx->depend_tensor_requires_grad) {
+      in_grads->at(1) =
+          JUST(functional::Constant(ctx->depend_tensor_shape, Scalar(0), ctx->depend_tensor_dtype,
+                                    JUST(ctx->depend_tensor_device)));
+    }
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("depend", Depend);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 3c1a2cf1035..b5485274a90 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -2764,6 +2764,14 @@
   signature: "Tensor (Tensor input) => IsFinite"
   bind_python: True
 
+- name: "depend"
+  signature:
+    [
+      "Tensor (Tensor input, Tensor depend) => Depend",
+      "Tensor (Tensor input, TensorTuple depends) => DependTuple",
+    ]
+  bind_python: True
+
 - name: "roc_auc_score"
   signature: "Tensor (Tensor label, Tensor pred) => RocAucScore"
   bind_python: True
diff --git a/oneflow/core/functional/impl/util_ops_functor.cpp b/oneflow/core/functional/impl/util_ops_functor.cpp
index a54276c70a7..b2670b56f9d 100644
--- a/oneflow/core/functional/impl/util_ops_functor.cpp
+++ b/oneflow/core/functional/impl/util_ops_functor.cpp
@@ -60,6 +60,48 @@ class IsFiniteFunctor final : public UtilOpsFunctor {
   }
 };
 
+class DependFunctor {
+ public:
+  DependFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
+                           const std::shared_ptr<one::Tensor>& depend_tensor) const {
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {in, depend_tensor});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class DependTupleFunctor {
+ public:
+  DependTupleFunctor() {
+    ops_.resize(kMaxInputCount);
+    for (int n = 0; n < ops_.size(); ++n) {
+      ops_[n] = CHECK_JUST(
+          one::OpBuilder("depend").Input("in").Input("depend_tensor").Output("out").Build());
+    }
+  }
+
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
+                           const one::TensorTuple& depends) const {
+    return _dispatch(in, depends, 0);
+  }
+
+ private:
+  Maybe<Tensor> _dispatch(const std::shared_ptr<one::Tensor>& in, const one::TensorTuple& depends,
+                          const int pos) const {
+    const size_t ndepend = depends.size();
+    Maybe<Tensor> output = OpInterpUtil::Dispatch<Tensor>(*ops_[pos], {in, depends[pos]});
+    if (pos == ndepend - 1) { return output; }
+    return _dispatch(JUST(output), depends, pos + 1);
+  }
+
+  std::vector<std::shared_ptr<OpExpr>> ops_;
+};
+
 }  // namespace impl
 
 using namespace impl;
 
@@ -67,6 +109,8 @@ using namespace impl;
 ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsNanFunctor>("IsNan"); };
 ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsInfFunctor>("IsInf"); };
 ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<IsFiniteFunctor>("IsFinite"); };
+ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<DependFunctor>("Depend"); };
+ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<DependTupleFunctor>("DependTuple"); };
 
 }  // namespace functional
 }  // namespace one
diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp
index 9b010eaba98..9f06215f61f 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.cpp
+++ b/oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -986,6 +986,9 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
 #ifdef WITH_CUDA
   JUST(DoPass("AutoMixedPrecision"));
 #endif
+  // prune depend OPs and add ctrl_in_op to op_conf accordingly
+  // to express the same semantics and avoid performance loss
+  JUST(DoPass("PruneDependOpPass"));
   JUST(DoPass("PruneAmpWhiteIdentityOpPass"));
   JUST(DoPass("OptimizerPlacementOptimizationPass"));
   // run FuseAddToOutputPass before IRRoundTripBeforeAD since add_2 maybe
diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto
index 774e1a984f9..7c6f1f0c6e7 100644
--- a/oneflow/core/job/job_conf.proto
+++ b/oneflow/core/job/job_conf.proto
@@ -266,6 +266,7 @@ message JobConfigProto {
   optional bool prune_parallel_cast_ops = 509 [default = true];
   optional bool prune_cast_to_static_shape_ops = 510 [default = true];
   optional bool prune_amp_white_identity_ops = 511 [default = true];
+  optional bool prune_depend_ops = 512 [default = true];
 
   optional bool cudnn_conv_enable_pseudo_half = 600 [default = true];
   optional bool enable_auto_mixed_precision = 602 [default = false];
diff --git a/oneflow/core/job/job_desc.h b/oneflow/core/job/job_desc.h
index 4d3446608af..0994dc01934 100644
--- a/oneflow/core/job/job_desc.h
+++ b/oneflow/core/job/job_desc.h
@@ -62,6 +62,7 @@ class JobDesc final {
   bool prune_parallel_cast_ops() const { return job_conf_.prune_parallel_cast_ops(); }
   bool prune_cast_to_static_shape_ops() const { return job_conf_.prune_cast_to_static_shape_ops(); }
   bool prune_amp_white_identity_ops() const { return job_conf_.prune_amp_white_identity_ops(); }
+  bool prune_depend_ops() const { return job_conf_.prune_depend_ops(); }
   bool enable_auto_parallel() const { return job_conf_.enable_auto_parallel(); }
   int64_t cudnn_buf_limit_mbyte() const { return job_conf_.cudnn_buf_limit_mbyte(); }
diff --git a/oneflow/core/job_rewriter/prune_depend_op_pass.cpp b/oneflow/core/job_rewriter/prune_depend_op_pass.cpp
new file mode 100644
index 00000000000..8d418c54b82
--- /dev/null
+++ b/oneflow/core/job_rewriter/prune_depend_op_pass.cpp
@@ -0,0 +1,241 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "oneflow/core/common/hash_container.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/graph/node.h"
+#include "oneflow/core/graph/op_graph.h"
+#include "oneflow/core/job_rewriter/job_pass.h"
+#include "oneflow/core/register/logical_blob_id.pb.h"
+
+namespace oneflow {
+
+namespace {
+
+struct UpdatedNodeInfo {
+  const OpNode* node = nullptr;
+  const OpNode* new_src_node = nullptr;
+  const OpNode* depend_node_nearest_src = nullptr;
+  const OpNode* depend_node_nearest_dst = nullptr;
+  std::vector<const OpNode*> new_in_ctrl_nodes;
+  bool updated = false;
+};
+
+bool IsDependyOp(const OperatorConf& op) {
+  return op.has_user_conf() && (op.user_conf().op_type_name() == "depend");
+}
+
+bool NeedDoPass(const Job& job) {
+  return std::any_of(job.net().op().cbegin(), job.net().op().cend(), IsDependyOp);
+}
+
+const OpNode* GetNodeFromEdgeByTensorName(const OpNode* op_node,
+                                          const std::string& target_tensor_name) {
+  CHECK(IsDependyOp(op_node->op().op_conf()));
+  for (const OpEdge* in_edge : op_node->in_edges()) {
+    const OpNode* in_op_node = in_edge->src_node();
+    const std::string& in_op_node_name = in_op_node->op().op_name();
+    const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns = in_edge->lbi2ibns();
+
+    for (const auto& item : lbi2ibns) {
+      const std::string& lbi_op_name = item.first.op_name();
+      for (const std::string& tensor_name : item.second) {
+        if (in_op_node_name == lbi_op_name && tensor_name == target_tensor_name) {
+          return in_op_node;
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+const OpNode* GetNodeFromInputEdge(const OpNode* op_node) {
+  return GetNodeFromEdgeByTensorName(op_node, "in_0");
+}
+
+const OpNode* GetNodeFromInCtrlEdge(const OpNode* op_node) {
+  return GetNodeFromEdgeByTensorName(op_node, "depend_tensor_0");
+}
+
+LogicalBlobId GetNewLbi(const OpNode* src_node, const OpNode* depend_node_nearest_src) {
+  CHECK(IsDependyOp(depend_node_nearest_src->op().op_conf()));
+  for (const OpEdge* out_edge : src_node->out_edges()) {
+    const OpNode* dst_node = out_edge->dst_node();
+    if (dst_node != depend_node_nearest_src) { continue; }
+
+    CHECK(out_edge->lbis().size() == 1);
+    return out_edge->lbis()[0];
+  }
+  // should not reach here
+  CHECK(false);
+  return {};
+}
+
+class PruneDependOpPass final : public JobPass {
+ public:
+  PruneDependOpPass() = default;
+  ~PruneDependOpPass() override = default;
+
+  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override;
+};
+
+Maybe<void> PruneDependOpPass::Apply(Job* job, JobPassCtx* ctx) const {
+  if (!ctx->job_desc().prune_depend_ops()) { return Maybe<void>::Ok(); }
+  if (!NeedDoPass(*job)) { return Maybe<void>::Ok(); }
+  const OpGraph op_graph(*job);
+  HashMap<std::string, UpdatedNodeInfo> node_info_with_update;
+  std::vector<const OpNode*> ordered_nodes;
+
+  // Step 0: topological sort, set up a map for recording modifications
+  op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) {
+    UpdatedNodeInfo node_info;
+    node_info.node = node;
+    node_info_with_update.emplace(node->op().op_name(), node_info);
+    ordered_nodes.emplace_back(node);
+  });
+
+  // Step 1: process nodes in topological order
+  // record modification info when meeting Depend OP nodes
+  for (const OpNode* cur_node : ordered_nodes) {
+    const std::string& cur_op_name = cur_node->op().op_name();
+    const OperatorConf& cur_op_conf = cur_node->op().op_conf();
+    if (!IsDependyOp(cur_op_conf)) { continue; }
+
+    // record modification info to each dst_node
+    for (const OpEdge* out_edge : cur_node->out_edges()) {
+      const OpNode* dst_node = out_edge->dst_node();
+      const Operator& dst_op = dst_node->op();
+
+      UpdatedNodeInfo& updated_dst_node_info = node_info_with_update.find(dst_op.op_name())->second;
+      UpdatedNodeInfo& updated_cur_node_info = node_info_with_update.find(cur_op_name)->second;
+      updated_dst_node_info.updated = true;
+      updated_dst_node_info.depend_node_nearest_dst = cur_node;
+
+      // Step 1.1: record a new in-ctrl node
+      const OpNode* cur_in_ctrl_node = GetNodeFromInCtrlEdge(cur_node);
+      updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_in_ctrl_node);
+
+      // Step 1.2: inherit in-ctrl nodes from Depend OP nodes
+      const auto& ori_in_ctrl_op_names = cur_op_conf.ctrl_in_op_name();
+      for (const std::string& ori_ctrl_in_op_name : ori_in_ctrl_op_names) {
+        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(
+            node_info_with_update[ori_ctrl_in_op_name].node);
+      }
+      if (updated_cur_node_info.updated) {
+        std::vector<const OpNode*>& inherit_in_ctrl_nodes = updated_cur_node_info.new_in_ctrl_nodes;
+        for (const OpNode* inherit_in_ctrl_node : inherit_in_ctrl_nodes) {
+          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(inherit_in_ctrl_node);
+        }
+      }
+
+      // Step 1.3: process src nodes
+      const OpNode* cur_src_node = GetNodeFromInputEdge(cur_node);
+      if (IsDependyOp(dst_node->op().op_conf()) && cur_node == GetNodeFromInCtrlEdge(dst_node)) {
+        // "cur_node" and "dst_node" are both Depend OP nodes, and their connection looks like:
+        //   other_node   cur_node
+        //            \   /
+        //          dst_node
+        // in this case, all src nodes of "cur_node" should be treated as in-ctrl nodes
+        if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
+          updated_dst_node_info.new_in_ctrl_nodes.emplace_back(updated_cur_node_info.new_src_node);
+        }
+        updated_dst_node_info.new_in_ctrl_nodes.emplace_back(cur_src_node);
+      } else {
+        if (!IsDependyOp(cur_src_node->op().op_conf())) {
+          updated_dst_node_info.new_src_node = cur_src_node;
+          updated_dst_node_info.depend_node_nearest_src = cur_node;
+        } else if (updated_cur_node_info.updated && updated_cur_node_info.new_src_node) {
+          updated_dst_node_info.new_src_node = updated_cur_node_info.new_src_node;
+          updated_dst_node_info.depend_node_nearest_src =
+              updated_cur_node_info.depend_node_nearest_src;
+        }
+      }
+    }
+  }
+
+  // Step 2: extract modification info,
+  // including new connections and the nodes to delete
+  std::vector<std::string> del_node_names;
+  HashMap<std::string, OperatorConf> to_update_op_confs;
+  for (const auto& node_info : node_info_with_update) {
+    // filter nodes not updated
+    if (!node_info.second.updated) { continue; }
+    const OpNode* cur_node = node_info.second.node;
+    const std::string& cur_op_name = cur_node->op().op_name();
+    // filter Depend nodes
+    if (IsDependyOp(cur_node->op().op_conf())) {
+      del_node_names.emplace_back(cur_op_name);
+      continue;
+    }
+
+    const Operator& cur_op = cur_node->op();
+    auto iter = to_update_op_confs.find(node_info.first);
+    if (iter == to_update_op_confs.end()) {
+      iter = to_update_op_confs.emplace(node_info.first, cur_op.op_conf()).first;
+    }
+    OperatorConf& cur_op_conf = iter->second;
+
+    // Step 2.1: connect the updated src_node with cur_node (the dst_node of the Depend OP)
+    const OpNode* src_node = node_info.second.new_src_node;
+    const OpNode* depend_node_nearest_dst = node_info.second.depend_node_nearest_dst;
+    const OpNode* depend_node_nearest_src = node_info.second.depend_node_nearest_src;
+    CHECK(src_node && depend_node_nearest_dst && depend_node_nearest_src);
+    const auto& old_lbi =
+        depend_node_nearest_dst->op().BnInOp2Lbi(depend_node_nearest_dst->op().SoleObn());
+    const auto new_lbi = GetNewLbi(src_node, depend_node_nearest_src);
+    for (const std::string& ibn : cur_node->op().input_bns()) {
+      if (cur_op.BnInOp2Lbi(ibn) == old_lbi) {
+        const auto& old_val =
+            ReplaceInputLbnInOpCustomizedConf(&cur_op_conf, ibn, GenLogicalBlobName(new_lbi));
+        CHECK_EQ(GenLogicalBlobName(old_lbi), old_val);
+        VLOG(3) << "Update input edge, Src Node: " << src_node->op().op_name()
+                << "\t->\tDst Node: " << cur_op_name;
+      }
+    }
+
+    // Step 2.2: add in-ctrl OPs
+    const auto& existed_ctrl_in_op_names = cur_op_conf.ctrl_in_op_name();
+    for (const OpNode* in_ctrl_node : node_info.second.new_in_ctrl_nodes) {
+      // filter Depend nodes
+      if (IsDependyOp(in_ctrl_node->op().op_conf())) { continue; }
+      CHECK(cur_node != in_ctrl_node);  // self-loop found
+      const std::string& new_ctrl_in_op_name = in_ctrl_node->op().op_name();
+      auto existed_it = std::find(existed_ctrl_in_op_names.begin(), existed_ctrl_in_op_names.end(),
+                                  new_ctrl_in_op_name);
+      // filter src node or duplicate in-ctrl nodes
+      if (in_ctrl_node != src_node && existed_it == existed_ctrl_in_op_names.end()) {
+        cur_op_conf.add_ctrl_in_op_name(new_ctrl_in_op_name);
+        VLOG(3) << "Add in-ctrl edge, Src Node: " << new_ctrl_in_op_name
+                << "\t->\tDst Node: " << cur_op_name;
+      }
+    }
+  }
+
+  // Step 3: apply modifications to the job
+  JobBuilder job_builder(job);
+  for (const auto& pair : to_update_op_confs) { job_builder.MutOpsOnlyOnce({pair.second}); }
+  job_builder.DelOps(del_node_names);
+  return Maybe<void>::Ok();
+};
+
+}  // namespace
+
+REGISTER_JOB_PASS("PruneDependOpPass", PruneDependOpPass);
+
+}  // namespace oneflow
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index ad3b450be57..116b0170127 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -5348,8 +5348,8 @@ def OneFlow_GroupedMatmulBiasOp : OneFlow_BaseOp<"grouped_matmul_bias", [NoSideE
 #endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS
 
 // Group: MISC
-// CategoricalOrdinalEncode, add_n, arange, bincount, coin_flip, concat, tensor_constant, constant, dropout, elementwise_maximum_backward, elementwise_minimum_backward, empty, eye, grid_sample_grad, multi_count_not_finite, multi_square_sum, nll, nll_grad, pow_x_grad, pow_y_grad, prelu_grad, randperm, recv, send, split_like, ssp_variable_proxy, tf_prelu_grad, uniform, uniform_int, unique, unique_with_counts, xdivy_x_grad, xdivy_y_grad, stack, stack_grad, fill_, fill_tensor_, exponential, multinomial_with_replacement, fused_weighted_sum
-// Total: 40
+// CategoricalOrdinalEncode, add_n, arange, bincount, coin_flip, concat, tensor_constant, constant, dropout, elementwise_maximum_backward, elementwise_minimum_backward, empty, eye, grid_sample_grad, multi_count_not_finite, multi_square_sum, nll, nll_grad, pow_x_grad, pow_y_grad, prelu_grad, randperm, recv, send, split_like, ssp_variable_proxy, tf_prelu_grad, uniform, uniform_int, unique, unique_with_counts, xdivy_x_grad, xdivy_y_grad, stack, stack_grad, fill_, fill_tensor_, exponential, multinomial_with_replacement, fused_weighted_sum, depend
+// Total: 41
 
 #ifdef GET_ONEFLOW_MISC_OP_DEFINITIONS
 
@@ -6100,6 +6100,20 @@ def OneFlow_FusedWeightedSumOp : OneFlow_BaseOp<"fused_weighted_sum", [NoSideEff
   let has_data_type_infer_fn = 1;
 }
 
+def OneFlow_DependOp : OneFlow_BaseOp<"depend", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$in,
+    OneFlow_Tensor:$depend_tensor
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
 #endif // GET_ONEFLOW_MISC_OP_DEFINITIONS
 
 // Group: NCCL
diff --git a/oneflow/user/kernels/copy_data_content_kernel.cpp b/oneflow/user/kernels/copy_data_content_kernel.cpp
index 50e1ad8860a..623ed0c3b6e 100644
--- a/oneflow/user/kernels/copy_data_content_kernel.cpp
+++ b/oneflow/user/kernels/copy_data_content_kernel.cpp
@@ -84,6 +84,7 @@ REGISTER_COPY_DATA_CONTENT_KERNEL("parallel_cast");
 REGISTER_COPY_DATA_CONTENT_KERNEL("hierarchical_parallel_cast");
 REGISTER_COPY_DATA_CONTENT_KERNEL("hierarchical_parallel_cast_like");
 REGISTER_COPY_DATA_CONTENT_KERNEL("pinned_identity");
+REGISTER_COPY_DATA_CONTENT_KERNEL("depend");
 
 }  // namespace
diff --git a/oneflow/user/ops/depend_op.cpp b/oneflow/user/ops/depend_op.cpp
new file mode 100644
index 00000000000..32a702b26ac
--- /dev/null
+++ b/oneflow/user/ops/depend_op.cpp
@@ -0,0 +1,53 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
+
+namespace oneflow {
+
+/* static */ Maybe<void> DependOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  ctx->SetOutputShape("out", 0, ctx->InputShape("in", 0));
+  ctx->SetOutputIsDynamic("out", 0, ctx->InputIsDynamic("in", 0));
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DependOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DependOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), i)
+        .Broadcast(user_op::OpArg("depend_tensor", 0))
+        .Split(user_op::OpArg("out", 0), i)
+        .Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .Broadcast(user_op::OpArg("depend_tensor", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DependOp::InferDataType(user_op::InferContext* ctx) {
+  ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0));
+  return Maybe<void>::Ok();
+}
+
+}  // namespace oneflow
diff --git a/python/oneflow/framework/docstr/__init__.py b/python/oneflow/framework/docstr/__init__.py
index 822675cecac..7fdacda753a 100644
--- a/python/oneflow/framework/docstr/__init__.py
+++ b/python/oneflow/framework/docstr/__init__.py
@@ -83,3 +83,4 @@
 from .linalg import *
 from .index_add import *
 from .baddbmm import *
+from .depend import *
diff --git a/python/oneflow/framework/docstr/depend.py b/python/oneflow/framework/docstr/depend.py
new file mode 100644
index 00000000000..b1fdb9e2d85
--- /dev/null
+++ b/python/oneflow/framework/docstr/depend.py
@@ -0,0 +1,57 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow
+from oneflow.framework.docstr.utils import add_docstr
+
+add_docstr(
+    oneflow._C.depend,
+    r"""
+    Add a control dependency to guarantee that OP A is executed before OP B.
+    Used to prevent OPs from being rearranged or eliminated during graph compilation.
+
+    Args:
+        input (Tensor): a tensor intended to be input to OP B
+        depend (Tensor or List[Tensor]): one of the output tensors of OP A (passing in multiple tensors from different OPs is supported)
+
+    Returns:
+        Tensor: the identity of the "input" tensor
+
+    Examples:
+        >>> import oneflow as flow
+        >>> import oneflow.nn as nn
+        >>> import oneflow.nn.functional as F
+        >>> class Model(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.OP_A = nn.Linear(128, 128)
+        ...         self.OP_B = nn.Linear(128, 128)
+        ...
+        ...     def forward(self, x):
+        ...         x1 = self.OP_A(x)
+        ...         x = F.depend(x, x1)
+        ...         return self.OP_B(x)
+        ...
+        >>> model = Model()
+        >>> class Graph(nn.Graph):
+        ...     def __init__(self) -> None:
+        ...         super().__init__()
+        ...         self.model = model
+        ...
+        ...     def build(self, x):
+        ...         return self.model(x)
+        ...
+        >>> graph = Graph()
+        >>> x = flow.randn([1, 128], dtype=flow.float32)
+        >>> y = graph(x)
+    """,
+)
diff --git a/python/oneflow/framework/function_util.py b/python/oneflow/framework/function_util.py
index 82336da91cb..e1de97984d8 100644
--- a/python/oneflow/framework/function_util.py
+++ b/python/oneflow/framework/function_util.py
@@ -532,6 +532,17 @@ def set_prune_amp_white_identity_ops(func_desc, value=True):
     func_desc.job_config_proto.prune_amp_white_identity_ops = value
 
 
+@oneflow_function_config("prune_depend_ops")
+def set_prune_depend_ops(func_desc, value=True):
+    """Whether to prune depend operations or not.
+
+    Args:
+        func_desc ([type]): [description]
+        value (bool, optional): [description]. Defaults to True.
+    """
+    func_desc.job_config_proto.prune_depend_ops = value
+
+
 @oneflow_function_config("non_distributed_optimizer_group_size_mbyte")
 def set_non_distributed_optimizer_group_size_mbyte(func_desc, value):
     print(
diff --git a/python/oneflow/nn/functional/__init__.py b/python/oneflow/nn/functional/__init__.py
index 5226385ec63..63ba8d542ff 100644
--- a/python/oneflow/nn/functional/__init__.py
+++ b/python/oneflow/nn/functional/__init__.py
@@ -89,3 +89,4 @@
 from .functional_deform_conv import deform_conv2d
 from oneflow._C import kl_div_loss as kl_div
 from oneflow._C import gumbel_softmax
+from .functional_depend import depend
diff --git a/python/oneflow/nn/functional/functional_depend.py b/python/oneflow/nn/functional/functional_depend.py
new file mode 100644
index 00000000000..e1cd46007d1
--- /dev/null
+++ b/python/oneflow/nn/functional/functional_depend.py
@@ -0,0 +1,78 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from oneflow.framework.tensor import Tensor
+import oneflow as flow
+from typing import Union, List
+
+
+def depend(input: Tensor, depend: Union[Tensor, List[Tensor]]) -> Tensor:
+    r"""
+    Add a control dependency to guarantee that OP A is executed before OP B.
+    Used to prevent OPs from being rearranged or eliminated during graph compilation.
+
+    Args:
+        input (Tensor): a tensor intended to be input to OP B
+        depend (Tensor or List[Tensor]): one of the output tensors of OP A (passing in multiple tensors from different OPs is supported)
+
+    Returns:
+        Tensor: the identity of the "input" tensor
+
+    Examples:
+        >>> import oneflow as flow
+        >>> import oneflow.nn as nn
+        >>> import oneflow.nn.functional as F
+        >>> class Model(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.OP_A = nn.Linear(128, 128)
+        ...         self.OP_B = nn.Linear(128, 128)
+        ...
+        ...     def forward(self, x):
+        ...         x1 = self.OP_A(x)
+        ...         x = F.depend(x, x1)
+        ...         return self.OP_B(x)
+        ...
+        >>> model = Model()
+        >>> class Graph(nn.Graph):
+        ...     def __init__(self) -> None:
+        ...         super().__init__()
+        ...         self.model = model
+        ...
+        ...     def build(self, x):
+        ...         return self.model(x)
+        ...
+        >>> graph = Graph()
+        >>> x = flow.randn([1, 128], dtype=flow.float32)
+        >>> y = graph(x)
+    """
+    # avoid performance loss in eager mode
+    if not input.is_lazy:
+        return input
+
+    # avoid self-loop
+    if isinstance(depend, Tensor) and input is depend:
+        raise RuntimeError('"input" and "depend" can NOT be the same tensor.')
+
+    if isinstance(depend, List):
+        for idx, t_depend in enumerate(depend):
+            if input is t_depend:
+                raise RuntimeError(
+                    '"input" and "depend[%d]" are the same tensor, which is not allowed.'
+                    % idx
+                )
+
+    return flow._C.depend(input, depend)
diff --git a/python/oneflow/test/graph/test_graph_depend.py b/python/oneflow/test/graph/test_graph_depend.py
new file mode 100644
index 00000000000..ac73d2bf237
--- /dev/null
+++ b/python/oneflow/test/graph/test_graph_depend.py
@@ -0,0 +1,227 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import numpy as np
+import unittest
+
+# used to observe operator optimization and execution order manually
+# import os
+# os.environ["ONEFLOW_DEBUG_MODE"] = "1"
+# os.environ["GLOG_v"] = "3"
+# os.environ["ENABLE_LOGICAL_CHAIN"] = "true"
+
+import oneflow as flow
+import oneflow.nn as nn
+import oneflow.unittest
+
+# NOTE: nn.functional.depend() behaves differently in the two modes.
+# In EAGER mode, the OP has no effect: the first parameter and the
+# output are the same tensor (like "y = x" in Python), while the
+# second parameter is ignored.
+
+
+def _build_graph_and_test(TestModel, in_data, test_case):
+
+    model = TestModel()
+    y_eager = model(in_data)
+
+    class TestGraph(flow.nn.Graph):
+        def __init__(self):
+            super().__init__()
+            self.model = model
+
+        def build(self, x):
+            return self.model(x)
+
+    graph = TestGraph()
+    # used to observe operator optimization and execution order manually
+    # graph.debug(3)
+    y_lazy = graph(in_data)
+    test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestDependGraph(oneflow.unittest.TestCase):
+    def test_depend_graph_case0(test_case):
+        class TestModel_0(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2" is executed before "self.linear(x)" in graph mode
+                # base use case
+                x1 = x * 2
+                x = nn.functional.depend(x, x1)
+                x2 = self.linear(x)
+                return x2
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_0, x, test_case)
+
+    def test_depend_graph_case1(test_case):
+        class TestModel_1(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2" and "x + 2" are executed before "self.linear(x)" in graph mode
+                # test multiple consecutive nn.functional.depend() calls in a logical chain
+                x1 = x * 2
+                x2 = x + 2
+                x = nn.functional.depend(x, x1)
+                x = nn.functional.depend(x, x2)
+                x3 = self.linear(x)
+                return x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_1, x, test_case)
+
+    def test_depend_graph_case2(test_case):
+        class TestModel_2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2" and "x + 2" are executed before "self.linear(x)" in graph mode
+                # some users may code like this
+                x1 = x * 2
+                x2 = x + 2
+                x2 = nn.functional.depend(x2, x1)
+                x = nn.functional.depend(x, x2)
+                x3 = self.linear(x)
+                return x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_2, x, test_case)
+
+    def test_depend_graph_case3(test_case):
+        class TestModel_3(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2", "x + 2" and "x - 2" are executed before "self.linear(x)" in graph mode
+                # a combination of the above cases
+                x1 = x * 2
+                x2 = x + 2
+                x3 = x - 2
+                x = nn.functional.depend(x, x1)
+                x2 = nn.functional.depend(x2, x3)
+                x = nn.functional.depend(x, x2)
+                x3 = self.linear(x)
+                return x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_3, x, test_case)
+
+    def test_depend_graph_case4(test_case):
+        class TestModel_4(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # the depend OP does nothing here and should be pruned from the graph correctly
+                x1 = x * 2
+                x2 = nn.functional.depend(x, x1)
+                x3 = self.linear(x)
+                return x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_4, x, test_case)
+
+    def test_depend_graph_case5(test_case):
+        class TestModel_5(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear0 = nn.Linear(128, 128)
+                self.linear1 = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2" is executed before "self.linear0(x)" and
+                # "self.linear1(x)" in graph mode
+                # to test the case where the depend OP connects to more than one OP
+                x1 = x * 2
+                x = nn.functional.depend(x, x1)
+                x2 = self.linear0(x)
+                x3 = self.linear1(x)
+                return x2 + x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_5, x, test_case)
+
+    def test_depend_graph_case6(test_case):
+        class TestModel_6(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x - 2" is executed before "self.linear(x)" in graph mode
+                # to test the case where the OP connected to the depend OP also connects to other OPs
+                x1 = x * 2
+                x2 = x1 - 2
+                x3 = nn.functional.depend(x2, x1)
+                x4 = self.linear(x3)
+                x5 = x2 + x4
+                return x5
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_6, x, test_case)
+
+    def test_depend_graph_case7(test_case):
+        class TestModel_7(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # to ensure "x + 2" is executed before "mp_values * 2" in graph mode
+                # to test the case where an OP with multiple outputs connects to the depend OP
+                x1 = x + 2
+                mp_values, mp_indices = nn.functional.max_pool1d(
+                    x, kernel_size=2, return_indices=True
+                )
+                mp_values = nn.functional.depend(mp_values, x1)
+                mp_values = mp_values * 2
+                return mp_values + mp_indices.to(flow.float32)
+
+        x = flow.randn([1, 2, 3], dtype=flow.float32)
+        _build_graph_and_test(TestModel_7, x, test_case)
+
+    def test_depend_graph_case8(test_case):
+        class TestModel_1(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                # to ensure "x * 2" and "x + 2" are executed before "self.linear(x)" in graph mode
+                # to test the case of passing in multiple depend tensors at a time
+                x1 = x * 2
+                x2 = x + 2
+                x = nn.functional.depend(x, [x1, x2])
+                x3 = self.linear(x)
+                return x3
+
+        x = flow.randn([1, 128], dtype=flow.float32)
+        _build_graph_and_test(TestModel_1, x, test_case)
+
+
+if __name__ == "__main__":
+    unittest.main()
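
For reference, below is a minimal end-to-end usage sketch of the nn.functional.depend API introduced by this patch. Module names, tensor shapes, and the tolerance are illustrative only; the sketch assumes a OneFlow build that already contains this change. It mirrors the docstring example: in eager mode depend() is a no-op identity, while in graph mode it becomes a "depend" op that PruneDependOpPass rewrites into a ctrl_in_op edge so op_a is scheduled before op_b.

import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F


class TwoStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.op_a = nn.Linear(8, 8)
        self.op_b = nn.Linear(8, 8)

    def forward(self, x):
        a_out = self.op_a(x)
        # In graph mode this inserts a "depend" op, which PruneDependOpPass
        # later rewrites into a control edge so op_a runs before op_b.
        # In eager mode depend() simply returns `x` unchanged.
        x = F.depend(x, a_out)
        return self.op_b(x)


class TwoStageGraph(nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def build(self, x):
        return self.model(x)


if __name__ == "__main__":
    model = TwoStage()
    x = flow.randn(1, 8)
    eager_y = model(x)                  # depend() is an identity here
    graph_y = TwoStageGraph(model)(x)   # depend() becomes a control dependency
    assert flow.allclose(eager_y, graph_y, atol=1e-5)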