From 95240c20bffd8becf186d39d5fda276bb507d843 Mon Sep 17 00:00:00 2001
From: Xiaoyu Xu
Date: Fri, 10 Jun 2022 14:40:47 +0800
Subject: [PATCH] Feat/zero mix with mp (#8036)

* add zero limit
* add debug
* add mix zero test
* refactor zero api
* zero test with mp
* add 2d test
* add zero nd
* add nd zero
* add sbp cast
* test passed soft limit consumer
* refine size api
* zero use stage 2
* add limit consumer api
* add new api
* refine zero s select
* fix index out of range
* rm zero limit on device type
* zero test with activation checkpointing
* add identity when dp sequence len is 1
* move to base with master
* fix
* fix
* fix
* add test
* debug bad case
* refine test for eager and graph boxing
* test case ready
* simplify
* refine test
* fix buff size
* fix conflict
* refine zero nd
* refine
* add full test
* revert change
* refine split check
* fix typo
* rm log
* split long func
* restore test
* Update optimizer_placement_optimization_pass.cpp
* auto format by CI
* auto format by CI
* fix static check
* add tips for zero api change
* auto format by CI

Co-authored-by: oneflow-ci-bot
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 docs/source/graph.rst                              |   3 +-
 oneflow/core/job/eager_nccl_comm_manager.cpp       |   3 +
 oneflow/core/job/job_build_and_infer_ctx.cpp       |   6 +
 oneflow/core/job/job_builder.cpp                   |   1 +
 oneflow/core/job/job_conf.proto                    |   1 +
 .../optimizer_placement_optimization_pass.cpp      | 290 +++++++++++++++---
 oneflow/user/ops/stack_op.cpp                      |   2 -
 .../oneflow/framework/multi_client_session.py      |   3 +
 python/oneflow/nn/graph/graph_config.py            | 133 ++++----
 python/oneflow/test/graph/test_graph_zero.py       | 182 +++++++++--
 .../test/graph/test_optimization_conf.py           |   2 +-
 11 files changed, 497 insertions(+), 129 deletions(-)

diff --git a/docs/source/graph.rst b/docs/source/graph.rst
index 5ec08061a8a..270e5a01cf0 100644
--- a/docs/source/graph.rst
+++ b/docs/source/graph.rst
@@ -20,12 +20,11 @@ Base class for running neural networks in Static Graph Mode.
 .. autoclass:: oneflow.nn.graph.graph_config.GraphConfig
     :members: enable_amp,
+              enable_zero,
               allow_fuse_model_update_ops,
               allow_fuse_add_to_output,
               allow_fuse_cast_scale,
               set_gradient_accumulation_steps,
-              set_zero_redundancy_optimizer_mode,
-              set_zero_redundancy_optimizer_min_size_after_split,
               enable_cudnn_conv_heuristic_search_algo,
     :member-order: bysource

diff --git a/oneflow/core/job/eager_nccl_comm_manager.cpp b/oneflow/core/job/eager_nccl_comm_manager.cpp
index d8b77cdbb72..959a7837010 100644
--- a/oneflow/core/job/eager_nccl_comm_manager.cpp
+++ b/oneflow/core/job/eager_nccl_comm_manager.cpp
@@ -71,6 +71,9 @@ void CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key,
           << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id) << ", rank = " << rank
           << ", key = {" << key << "}\n";
   OF_NCCL_CHECK(ncclCommInitRank(comm, device_vec.size(), nccl_unique_id, rank));
+  VLOG(2) << " EagerNcclCommMgr::ncclCommInitRank succeeded, device_vec.size() = "
+          << device_vec.size() << ", nccl_unique_id = " << NcclUniqueId2String(nccl_unique_id)
+          << ", rank = " << rank << ", key = {" << key << "}\n";
 }
 
 }  // namespace
diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp
index 5afc24192da..8ae659fd541 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.cpp
+++ b/oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -997,13 +997,19 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
     }
   };
   int32_t pass_cnt = 0;
+  const int64_t prev_v = FLAGS_v;
   auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe<void> {
+    VLOG(1) << job_name << " is compiling with pass"
+            << " pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name
+            << (cnt > 0 ? std::to_string(cnt) : "");
     if (unlikely(NeedLogJob(pass_name))) {
       std::string cnt_str = cnt > 0 ? std::to_string(cnt) : "";
       LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-before");
+      FLAGS_v = 3;
     }
     JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx));
     if (unlikely(NeedLogJob(pass_name))) {
+      FLAGS_v = prev_v;
       std::string cnt_str = cnt > 0 ? std::to_string(cnt) : "";
       LogJob("pass_cnt_" + std::to_string(pass_cnt) + "-" + pass_name + cnt_str + "-after");
     }
diff --git a/oneflow/core/job/job_builder.cpp b/oneflow/core/job/job_builder.cpp
index a7f81384376..b13bd8a67fd 100644
--- a/oneflow/core/job/job_builder.cpp
+++ b/oneflow/core/job/job_builder.cpp
@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/common/util.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 69aa7ad29f0..03638feec30 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -211,6 +211,7 @@ message JobConfigProto { optional bool enable_gradients_stats_aggregation = 106 [default = true]; optional string optimizer_placement_optimization_mode = 107; optional int64 optimizer_placement_optimization_threshold = 108 [default = 1024]; + optional int64 optimizer_placement_optimization_shard_restore_level = 110 [default = 2]; optional QatConfig qat_config = 109; diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 1ca857fd11f..7aaf2e75426 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -13,10 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "oneflow/core/common/util.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/user_op_conf.h" +#include "oneflow/core/job/nd_sbp_util.h" +#include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/job_rewriter/job_pass.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/job_desc.h" +#include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/operator/operator.h" namespace oneflow { @@ -31,7 +40,7 @@ int64_t GetSoleOutBlobSize(const OpNode* node) { class DataParallelNodeSequence final { public: DataParallelNodeSequence(std::vector nodes, int64_t order) - : nodes_(std::move(nodes)), order_(order) { + : nodes_(std::move(nodes)), order_(order), len_(nodes_.size()) { const OpNode* var_node = nodes_.front(); CHECK(var_node->op().op_conf().has_variable_conf()); model_size_ = GetSoleOutBlobSize(var_node); @@ -50,13 +59,23 @@ class DataParallelNodeSequence final { int64_t model_size() const { return model_size_; } + int64_t len() const { return len_; } + + void resize(const int64_t size) { + CHECK(size <= len_); + CHECK(size > 1); + nodes_.resize(size); + len_ = nodes().size(); + } + private: std::vector nodes_; int64_t order_; int64_t model_size_; + int64_t len_; }; -using SequencePtr = std::shared_ptr; +using SequencePtr = std::shared_ptr; ParallelConf NonDistributedParallelConf4ParallelId(const ParallelDesc& pd, const int64_t parallel_id) { @@ -76,7 +95,6 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( // Find sequence like: vairable -> cast_fp32_to_fp16 if (!start->op().op_conf().has_variable_conf()) { return Maybe::Ok(); } const ParallelDesc& pd = start->parallel_desc(); - if (pd.device_type() != DeviceType::kCUDA) { return Maybe::Ok(); } if (pd.parallel_num() == 1) { return Maybe::Ok(); } const OpNode* cur_node = start; while (cur_node != nullptr) { @@ -85,12 +103,21 @@ Maybe GetDataParallelVariableAndNaiveSuccNode( if (cur_node->in_edges().size() > 1) { break; } if (cur_node->op().input_bns().size() != 1) { break; } const std::string& sole_ibn = cur_node->op().SoleIbn(); - if 
-      if (!cur_node->SbpParallel4BnInOp(sole_ibn).has_broadcast_parallel()) { break; }
+      const NdSbp& ibn_nd_sbp = cur_node->NdSbp4BnInOp(sole_ibn);
+      bool has_broadcast = false;
+      FOR_RANGE(int, i, 0, ibn_nd_sbp.sbp_parallel_size()) {
+        if (ibn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }
+      }
+      if (!has_broadcast) { break; }
     }
-    if (!IsAllowed(cur_node)) { break; }
     if (cur_node->op().output_bns().size() != 1) { break; }
     const std::string& sole_obn = cur_node->op().SoleObn();
-    if (!cur_node->SbpParallel4BnInOp(sole_obn).has_broadcast_parallel()) { break; }
+    const NdSbp& obn_nd_sbp = cur_node->NdSbp4BnInOp(sole_obn);
+    bool has_broadcast = false;
+    FOR_RANGE(int, i, 0, obn_nd_sbp.sbp_parallel_size()) {
+      if (obn_nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { has_broadcast = true; }
+    }
+    if (!has_broadcast) { break; }
     out->emplace_back(cur_node);
     if (cur_node->out_edges().size() == 1) {
       cur_node = cur_node->SoleOutEdge()->dst_node();
@@ -123,6 +150,79 @@ void SetBroadcastParallel4Consumers(JobBuilder* builder, const SequencePtr& sequ
   });
 }
 
+void SetNdSbp4OpNodeIbn(JobBuilder* builder, const OpNode* node, const std::string& ibn,
+                        const NdSbp& nd_sbp) {
+  OpBlobArg op_blob_arg;
+  op_blob_arg.set_op_name(node->op().op_name());
+  op_blob_arg.set_bn_in_op(ibn);
+  builder->SetNdSbp4Oba(op_blob_arg, nd_sbp);
+}
+
+void SetNdSbp4Consumers(JobBuilder* builder, const SequencePtr& sequence, const NdSbp& nd_sbp) {
+  const OpNode* node = sequence->GetLastNode();
+  const LogicalBlobId& lbi = node->op().BnInOp2Lbi(node->op().SoleObn());
+  const int64_t shard_restore_level =
+      builder->job().job_conf().optimizer_placement_optimization_shard_restore_level();
+  // If shard_restore_level == 0, there is no limit on consumers.
+  if (shard_restore_level == 1) {
+    // Input lbn of the parallel cast op.
+    std::string parallel_cast_input_lbn = GenLogicalBlobName(lbi);
+    // Add an identity op to enable mem reuse of the boxing op when there is no op between the
+    // variable op and the boxing.
+    if (sequence->len() == 1) {
+      VLOG(3) << "ZeRO found a data-parallel sequence with only one variable "
+              << sequence->GetVariableNode()->op().op_name();
+      const auto var_identity_op =
+          user_op::UserOpConfWrapperBuilder("System-ZeRO-Identity-" + node->op().op_name() + "-"
+                                            + NewUniqueId())
+              .Op("identity")
+              .Input("in", GenLogicalBlobName(lbi))
+              .Output("out")
+              .ScopeSymbolId(node->op().op_conf().scope_symbol_id())
+              .Build();
+      builder->AddOps(node->parallel_desc().parallel_conf(), {var_identity_op.op_conf()});
+      parallel_cast_input_lbn = var_identity_op.output("out", 0);
+    }
+    // Add a parallel cast op to make a soft limit on consumers so they consume the weight with
+    // Broadcast SBP.
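+    // Soft restore: only the cast op's output advertises the variable's original (pre-ZeRO)
+    // nd_sbp, so consumers keep their SBP inference unchanged while the variable itself stays
+    // sharded; grad_mode "identity" below keeps the backward pass from inserting a reverse cast.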
+    const auto parallel_cast_op =
+        user_op::UserOpConfWrapperBuilder("System-ZeRO-ParallelCast-" + node->op().op_name() + "-"
+                                          + NewUniqueId())
+            .Op("hierarchical_parallel_cast")
+            .Input("in", parallel_cast_input_lbn)
+            .Output("out")
+            .Attr<std::vector<std::string>>("nd_sbp", NdSbpToStringList(nd_sbp))
+            .Attr<std::string>("grad_mode", "identity")  // don't do nd_sbp cast at backward
+            .Attr<std::vector<std::string>>("grad_nd_sbp", std::vector<std::string>())
+            .ScopeSymbolId(node->op().op_conf().scope_symbol_id())
+            .Build();
+    builder->AddOps(node->parallel_desc().parallel_conf(), {parallel_cast_op.op_conf()});
+
+    // Make consumers consume the parallel cast op's output.
+    auto out_lbn = parallel_cast_op.output("out", 0);
+    node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {
+      for (const std::string& ibn : out_node->op().input_bns()) {
+        if (out_node->op().BnInOp2Lbi(ibn) == lbi) {
+          if (!CHECK_JUST(builder->IsInMutOpTransaction(out_node->op().op_name()))) {
+            CHECK_JUST(builder->MutOpTransactionMut(out_node->op().op_conf()));
+          }
+          OperatorConf& mut_consumer_op =
+              CHECK_JUST(builder->MutOpTransactionGet(out_node->op().op_name()));
+          const auto& old_lbn = ReplaceInputLbnInOpCustomizedConf(&mut_consumer_op, ibn, out_lbn);
+          CHECK_EQ(old_lbn, GenLogicalBlobName(lbi));
+        }
+      }
+    });
+  } else if (shard_restore_level == 2) {
+    // Hard limit: consumers must consume the weight as Broadcast.
+    node->ForEachNodeOnOutEdge([&](const OpNode* out_node) {
+      for (const std::string& ibn : out_node->op().input_bns()) {
+        if (out_node->op().BnInOp2Lbi(ibn) == lbi) {
+          SetNdSbp4OpNodeIbn(builder, out_node, ibn, nd_sbp);
+        }
+      }
+    });
+  }
+}
+
 std::function<int64_t(const OpNode*)> MakeGetterOpNode2TopoOrder(const OpGraph& op_graph) {
   HashMap<const OpNode*, int64_t> op_node2topo_order;
   int64_t node_cnt = 0;
@@ -152,7 +252,7 @@ void ForEachDataParallelNodeSequence(const OpGraph& op_graph,
     CHECK_JUST(GetDataParallelVariableAndNaiveSuccNode(node, IsAllowed, &nodes));
     if (nodes.empty()) { return; }
     const int64_t order = GetMinConsumerOrder(op_graph, nodes.back(), OpNode2Order);
-    Handler(std::make_shared<const DataParallelNodeSequence>(std::move(nodes), order));
+    Handler(std::make_shared<DataParallelNodeSequence>(std::move(nodes), order));
   });
 }
 
@@ -188,6 +288,24 @@ bool IsS0Parallel(const SbpSignature& signature, const std::string& bn) {
   return IsS0Parallel(signature.bn_in_op2sbp_parallel().at(bn));
 }
 
+bool IsNdSbpMatch(const NdSbpSignature& signature, const std::string& bn, const NdSbp& nd_sbp) {
+  return signature.bn_in_op2nd_sbp().at(bn) == nd_sbp;
+}
+
+bool IsNdSbpSupported4Op(const OpNode* node, const NdSbp& nd_sbp) {
+  if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; }
+  std::vector<NdSbpSignature> list;
+  auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe<const BlobDesc&> {
+    return Maybe<const BlobDesc&>(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn)));
+  };
+  CHECK_JUST(node->op().GetNdSbpSignatureList(LogicalBlobDesc4Ibn, node->parallel_desc(), &list));
+  const auto IsInAndOutMatch = [&](const NdSbpSignature& signature) {
+    return IsNdSbpMatch(signature, node->op().SoleIbn(), nd_sbp)
+           && IsNdSbpMatch(signature, node->op().SoleObn(), nd_sbp);
+  };
+  return std::any_of(list.cbegin(), list.cend(), IsInAndOutMatch);
+}
+
 bool IsS0SignatureSupported(const OpNode* node) {
   if (node->op().input_bns().size() != 1 || node->op().output_bns().size() != 1) { return false; }
   SbpSignatureList list;
@@ -222,42 +340,141 @@ void ForEachModelSizeBalancedPartition(
   }
 }
 
-Maybe<void> RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) {
-  const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold();
-  const auto IsAllowed = [threshold](const OpNode* n) -> bool {
-    if (n->op().op_conf().has_variable_conf()) {
-      const Shape shape(n->op().op_conf().variable_conf().shape());
-      const int64_t parallel_num = n->parallel_desc().parallel_num();
-      // Parameter needs to be able to evenly splited and one slice size >= threshold
-      return shape.At(0) % parallel_num == 0 && shape.elem_cnt() >= threshold * parallel_num;
-    } else {
-      return IsS0SignatureSupported(n);
-    }
-  };
+namespace {
+bool IsSplitValid(const Shape& shape, const NdSbp& nd_sbp, const Shape& hierarchy,
+                  int64_t min_size) {
+  if (shape.NumAxes() < 1 || shape.elem_cnt() < 1) { return false; }
+  CHECK_EQ(nd_sbp.sbp_parallel_size(), hierarchy.NumAxes());
+  Shape cur_shape = shape;
+  if (cur_shape.elem_cnt() < min_size) { return false; }
+  FOR_RANGE(int64_t, i, 0, hierarchy.NumAxes()) {
+    const auto& sbp = nd_sbp.sbp_parallel(i);
+    if (sbp.has_split_parallel()) {
+      const int64_t dim = sbp.split_parallel().axis();
+      if (dim >= cur_shape.NumAxes()) { return false; }
+      // Evenly split.
+      if (cur_shape.At(dim) % hierarchy.At(i) != 0) { return false; }
+      cur_shape.Set(dim, cur_shape.At(dim) / hierarchy.At(i));
+      // Larger than min size.
+      if (cur_shape.elem_cnt() < min_size) { return false; }
+    }
+  }
+  return true;
+}
+
+void GenerateSplitSignature(const NdSbp& var_nd_sbp, const OperatorConf& new_var_op_conf,
+                            std::string& new_split_signature, int64_t& split_dim) {
+  if (new_var_op_conf.variable_conf().nd_sbp_size() > 0 && NdSbpIsAllBroadcast(var_nd_sbp)) {
+    // Split the last dim.
+    split_dim = new_var_op_conf.variable_conf().nd_sbp_size() - 1;
+    // All B, B -> S0
+    new_split_signature = "S(0)";
+  } else {
+    // ND sbp, (*, B, S, *) -> (*, S, S, *)
+    // ND sbp, (*, S, B, *) -> (*, S, S, *)
+    FOR_RANGE(int64_t, j, 0, new_var_op_conf.variable_conf().nd_sbp_size()) {
+      if (new_var_op_conf.variable_conf().nd_sbp(j) == "B") {
+        std::vector<int64_t> adjacent_dim{j - 1, j + 1};
+        for (auto const& dim_to_try : adjacent_dim) {
+          if (dim_to_try >= 0 && dim_to_try < new_var_op_conf.variable_conf().nd_sbp_size()) {
+            SbpParallel sbp;
+            if (ParseSbpParallelFromString(new_var_op_conf.variable_conf().nd_sbp(dim_to_try), &sbp)
+                && sbp.has_split_parallel()) {
+              new_split_signature = new_var_op_conf.variable_conf().nd_sbp(dim_to_try);
+              split_dim = j;
+            }
+          }
+          if (new_split_signature != "") break;
+        }
+      }
+      // Only split one more dim.
+      if (new_split_signature != "") break;
+    }
+  }
+}
+
+void ShardSequence(JobBuilder* builder, const int64_t threshold, const ParallelDesc& pd,
+                   std::vector<SequencePtr>&& sorted_sequences) {
+  // For every sorted sequence, set the variable op in the sequence to S
+  // and add a ctrl edge to control the execution order between variable ops.
+  // A sequence is a variable op and its cast (fp32 to fp16) op. This is because the forward pass
+  // consumes the fp16 variable and the optimizer consumes the fp32 variable.
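+  // For each sequence: derive a candidate split signature from the variable's current nd_sbp,
+  // validate it (even split, shard size >= threshold), trim the sequence to the prefix of ops
+  // that support the new nd_sbp, then rewrite the variable op and redirect its consumers.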
+  std::string prev_allowed_op_name = "";
+  for (int64_t i = 0; i < sorted_sequences.size(); ++i) {
+    const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode();
+    OperatorConf new_var_op_conf = var_node->op().op_conf();
+    const std::string& sole_obn = var_node->op().SoleObn();
+    const NdSbp& var_nd_sbp = var_node->NdSbp4BnInOp(sole_obn);
+    const Shape& logical_shape = Shape(new_var_op_conf.variable_conf().shape());
+
+    std::string new_split_signature = "";
+    int64_t split_dim = 0;
+    GenerateSplitSignature(var_nd_sbp, new_var_op_conf, new_split_signature, split_dim);
+    if (new_split_signature != "") {
+      *new_var_op_conf.mutable_variable_conf()->mutable_nd_sbp(split_dim) = new_split_signature;
+    } else {
+      continue;
+    }
+
+    bool split_is_allowed = true;
+    if (split_is_allowed) {
+      NdSbp new_nd_sbp;
+      std::vector<std::string> nd_sbp_str_vec;
+      for (const auto& sbp_str : new_var_op_conf.variable_conf().nd_sbp()) {
+        nd_sbp_str_vec.push_back(sbp_str);
+      }
+      ParseNdSbpFromStringList(nd_sbp_str_vec, &new_nd_sbp);
+      // Check that the split is allowed by the min shard size and evenly splits.
+      if (split_is_allowed) {
+        split_is_allowed = IsSplitValid(logical_shape, new_nd_sbp, *pd.hierarchy(), threshold);
+      }
+      if (split_is_allowed) {
+        // Resize the sequence by the new nd sbp limit.
+        auto& cur_seq = sorted_sequences.at(i);
+        int64_t max_len = 1;
+        if (cur_seq->len() > 1) {
+          FOR_RANGE(int64_t, node_idx, 1, cur_seq->len()) {
+            if (IsNdSbpSupported4Op(cur_seq->nodes().at(node_idx), new_nd_sbp)) {
+              ++max_len;
+            } else {
+              break;
+            }
+          }
+        }
+        if (max_len < cur_seq->len()) { cur_seq->resize(max_len); }
+      }
+    }
+    if (!split_is_allowed) {
+      VLOG(3) << var_node->op().op_name() << " failed to change from B to S"
+              << " with op conf " << new_var_op_conf.variable_conf().DebugString();
+      continue;
+    }
+    if (i != 0) { new_var_op_conf.add_ctrl_in_op_name(prev_allowed_op_name); }
+    builder->MutOpsOnlyOnce({new_var_op_conf});
+    // Set consumers to consume this variable op's cast op's output as Broadcast.
+    if (new_split_signature != "") {
+      SetNdSbp4Consumers(builder, sorted_sequences.at(i), var_nd_sbp);
+    }
+    prev_allowed_op_name = var_node->op().op_name();
+    VLOG(3) << var_node->op().op_name() << " succeeded in changing from B to "
+            << new_split_signature << " on ranks dim " << split_dim << " with op conf "
+            << new_var_op_conf.variable_conf().DebugString();
+  }
+}
+}  // namespace
+
+Maybe<void> RewriteDistributedSplit(const OpGraph& op_graph, JobBuilder* builder) {
+  const int64_t threshold = builder->job().job_conf().optimizer_placement_optimization_threshold();
+  const auto IsAllowed = [](const OpNode* n) -> bool {
+    // No need to limit here.
+    return true;
+  };
   const auto PlacementSequencesAsSplitParallel = [&](const ParallelDesc& pd,
                                                      std::vector<SequencePtr>&& sorted_sequences) {
-    // For all sorted sequnence, set the variable op in the sequence to S(0)
-    // and add ctrl edge to control the exectuion order between variable ops.
-    // A sequence is a variable op and its cast(fp32 to fp16) op. This is because the forward pass
-    // consume the fp16 variable and the optimizer consume the fp32 variable.
-    for (int64_t i = 0; i < sorted_sequences.size(); ++i) {
-      const OpNode* var_node = sorted_sequences.at(i)->GetVariableNode();
-      OperatorConf new_var_op_conf = var_node->op().op_conf();
-      CHECK_EQ(pd.hierarchy()->NumAxes(), 1);
-      new_var_op_conf.mutable_variable_conf()->clear_nd_sbp();
-      *new_var_op_conf.mutable_variable_conf()->add_nd_sbp() = "S(0)";
-      if (i != 0) {
-        const std::string& prev_op_name =
-            sorted_sequences.at(i - 1)->GetVariableNode()->op().op_name();
-        new_var_op_conf.add_ctrl_in_op_name(prev_op_name);
-      }
-      builder->MutOpsOnlyOnce({new_var_op_conf});
-      // Set consumers to consum this variable op's cast op's output as Broadcast.
-      SetBroadcastParallel4Consumers(builder, sorted_sequences.at(i));
-    }
+    ShardSequence(builder, threshold, pd, std::forward<std::vector<SequencePtr>>(sorted_sequences));
   };
   ForEachParallelSortedNodeSequence(op_graph, IsAllowed, SequenceCompSortedByOrderAsc,
                                     PlacementSequencesAsSplitParallel);
+  JUST(builder->MutOpTransactionCommit());
   return Maybe<void>::Ok();
 }
 
@@ -313,7 +530,8 @@ class OptimizerPlacementOptimizationPass final : public JobPass {
 
   Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {
     if (!(ctx->job_desc().IsTrain()
-          && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode())) {
+          && ctx->job_desc().job_conf().has_optimizer_placement_optimization_mode()
+          && ctx->job_desc().job_conf().optimizer_placement_optimization_mode() != "none")) {
       return Maybe<void>::Ok();
     }
     const std::string& mode = ctx->job_desc().job_conf().optimizer_placement_optimization_mode();
diff --git a/oneflow/user/ops/stack_op.cpp b/oneflow/user/ops/stack_op.cpp
index 254cbcd1743..1dd129081bd 100644
--- a/oneflow/user/ops/stack_op.cpp
+++ b/oneflow/user/ops/stack_op.cpp
@@ -144,8 +144,6 @@ Maybe<void> GenGradOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn&
 
 /*static*/ Maybe<void> StackGradOp::GetSbp(user_op::SbpContext* ctx) {
   const auto axis = ctx->Attr<int64_t>("axis");
-  const int64_t in_num_axes =
-      ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes();
   const int64_t like_num_axes =
       ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes();
   FOR_RANGE(int64_t, i, 0, like_num_axes) {
diff --git a/python/oneflow/framework/multi_client_session.py b/python/oneflow/framework/multi_client_session.py
index 72c6e093779..64a82c12b27 100644
--- a/python/oneflow/framework/multi_client_session.py
+++ b/python/oneflow/framework/multi_client_session.py
@@ -124,4 +124,7 @@ def update_resource_eagerly(self, resource_config):
         self._session_ctx.update_resource(config_proto_str)
 
     def __del__(self):
+        if self._env.is_shutting_down():
+            # Once Python starts shutting down, it is no longer safe to call into oneflow.
+            return
         self._TryClose()
diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py
index dfb3795b3ac..ea48ad8d957 100644
--- a/python/oneflow/nn/graph/graph_config.py
+++ b/python/oneflow/nn/graph/graph_config.py
@@ -17,6 +17,7 @@
 
 from collections import OrderedDict
 
+import oneflow.boxing.nccl as nccl_config
 from oneflow.nn.graph.optimizer import OptDict
 import oneflow.core.job.job_conf_pb2 as job_conf_pb
 
@@ -45,24 +46,51 @@ def training(self):
                 return False
         raise NotImplementedError
 
-    def set_outputs_buffer_size(self, value: int = 2):
-        r"""Set the outputs buffer size of ``nn.Graph``.
+    def enable_amp(self, mode: bool = True):
+        r"""If set to true, the graph will use mixed precision mode, which means both float16 and float32 are used during model training.
-        When graph's outputs buffer size is greater than 2, multiple call on the graph can work like a pipeline. This makes multiple call takes less time.
+        For example:
 
-        The default outputs buffer size is 2.
+        .. code-block:: python
 
-        # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code.
-        # The size of the buffer size indicates the maximum number of iterations that the output of the Graph and the Graph actually executed asynchronously can overlap.
-        # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time.
+            import oneflow as flow
+
+            class Graph(flow.nn.Graph):
+                def __init__(self):
+                    super().__init__()
+                    self.linear = flow.nn.Linear(3, 8, False)
+                    self.config.enable_amp(True)  # Use mixed precision mode.
+                def build(self, x):
+                    return self.linear(x)
+
+            graph = Graph()
 
         Args:
-            value (int): graph ouputs buffer size.
+            mode (bool, optional): The default value is True.
         """
-        self._outputs_buffer_size = value
+        assert type(mode) is bool
+        self.proto.enable_auto_mixed_precision = mode
 
-    def enable_amp(self, mode: bool = True):
-        r"""If set to true, then graph will use mixed precision mode, it means use both float16 and float32 during model training.
+    def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"):
+        raise RuntimeError(
+            "`set_zero_redundancy_optimizer_mode` has been changed to `enable_zero`, please use `enable_zero(True)` to activate ZeRO optimization."
+        )
+
+    def enable_zero(
+        self,
+        mode: bool = True,
+        *,
+        stage: int = 2,
+        shard_min_size: int = 1024,
+        shard_restore_level: int = 1,
+    ):
+        r"""Enable ZeRO redundancy optimizer.
+
+        This optimization will reduce optimizer states memory consumption as described
+        by ZeRO https://arxiv.org/abs/1910.02054 .
+
+        The default zero stage is 2.
 
         For example:
 
         .. code-block:: python
 
             import oneflow as flow
 
@@ -74,17 +102,36 @@ class Graph(flow.nn.Graph):
             def __init__(self):
                 super().__init__()
                 self.linear = flow.nn.Linear(3, 8, False)
-                self.config.enable_amp(True) # Use mixed precision mode.
+                self.config.enable_zero()
             def build(self, x):
                 return self.linear(x)
 
         graph = Graph()
 
         Args:
-            mode (bool, optional): The default vaule is True.
+            mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices.
+            stage (int): optimization stage, ranging from 1 to 3.
+            shard_min_size (int): minimum size of a shard of an optimizer state.
+            shard_restore_level (int): level at which a sharded parameter is restored to the whole parameter for consumer operators; level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this parameter is at pre-alpha stage.
         """
-        assert type(mode) is bool
-        self.proto.enable_auto_mixed_precision = mode
+        if not mode:
+            self.proto.optimizer_placement_optimization_mode = "none"
+            return
+        assert stage >= 1 and stage <= 3, "ZeRO stage must range from 1 to 3."
+        assert (
+            shard_min_size > 0
+        ), "ZeRO min size of a sharded optimizer state must be > 0."
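+        # Map the single stage knob onto the underlying job conf / nccl switches:
+        # stage 1 shards optimizer states via the "distributed_split" placement pass;
+        # stage 2 additionally makes NCCL use the compute stream; stage 3 additionally
+        # disables group boxing by dst parallel.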
+        if stage >= 1:
+            self.proto.optimizer_placement_optimization_mode = "distributed_split"
+            self.proto.optimizer_placement_optimization_threshold = shard_min_size
+            self.proto.optimizer_placement_optimization_shard_restore_level = (
+                shard_restore_level
+            )
+        if stage >= 2:
+            nccl_config.enable_use_compute_stream(True)
+        if stage >= 3:
+            nccl_config.disable_group_boxing_by_dst_parallel(True)
 
     def allow_fuse_model_update_ops(self, mode: bool = True):
         r"""If set to true, try to fuse cast + scale + l1_l2_regularize_gradient + model_update to one op to improve performance.
@@ -188,61 +235,23 @@ def build(self, x):
         """
         self.proto.num_gradient_accumulation_steps = value
 
-    def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"):
-        r"""Set mode to remove redundancy of optimizer states.
-        This optimzation will reduce optimizer states memory consumption as described
-        by ZeRO https://arxiv.org/abs/1910.02054 .
-
-        For example:
-
-        .. code-block:: python
-
-            import oneflow as flow
-
-            class Graph(flow.nn.Graph):
-                def __init__(self):
-                    super().__init__()
-                    self.linear = flow.nn.Linear(3, 8, False)
-                    self.config.set_zero_redundancy_optimizer_mode("distributed_split")
-                def build(self, x):
-                    return self.linear(x)
-
-            graph = Graph()
-
-        Args:
-            mode (str): "distributed_split" or "non_distributed". "distributed_split" mode
-                will shard each optimizer state across devices. "non_distributed" mode
-                will place each optimizer state to only one device.
-        """
-        assert mode in ("distributed_split", "non_distributed")
-        self.proto.optimizer_placement_optimization_mode = mode
-
-    def set_zero_redundancy_optimizer_min_size_after_split(self, value):
-        r"""Set the min size of optimizer state/grad/parameter after split.
-
-        For example:
-
-        .. code-block:: python
-
-            import oneflow as flow
-
-            class Graph(flow.nn.Graph):
-                def __init__(self):
-                    super().__init__()
-                    self.linear = flow.nn.Linear(3, 8, False)
-                    self.config.set_zero_redundancy_optimizer_mode("distributed_split")
-                    self.config.set_zero_redundancy_optimizer_min_size_after_split(1)
-                def build(self, x):
-                    return self.linear(x)
-
-            graph = Graph()
+    def set_outputs_buffer_size(self, value: int = 2):
+        r"""Set the outputs buffer size of ``nn.Graph``.
+
+        When graph's outputs buffer size is greater than 2, multiple calls on the graph can work like a pipeline. This makes multiple calls take less time.
+
+        The default outputs buffer size is 2.
+
+        # TODO (lixiang): Explain the meaning of the size of buffer size and add sample code.
+        # The size of the buffer size indicates the maximum number of iterations that the output of the Graph and the Graph actually executed asynchronously can overlap.
+        # If the buffer size is 1, there is no pipeline. A size of 2 means that it can execute 1 iter ahead of time. A size of 3 means that two iters can be executed ahead of time.
 
         Args:
-            value (int): min size value.
+            value (int): graph outputs buffer size.
         """
         assert isinstance(value, int)
         assert value >= 1
-        self.proto.optimizer_placement_optimization_threshold = value
+        self._outputs_buffer_size = value
 
     def enable_cudnn_conv_heuristic_search_algo(self, mode: bool = True):
         r""" Whether enable cudnn conv operatioin to use heuristic search algorithm.
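Migration sketch for the renamed API (illustrative only; the module shape, learning rate, and the 2-GPU placement below are assumptions, not taken from this patch):

.. code-block:: python

    import oneflow as flow

    P = flow.placement("cuda", ranks=[0, 1])
    # Broadcast (data-parallel) parameter; the ZeRO pass reshards it at compile time.
    linear = flow.nn.Linear(8, 4).to_global(placement=P, sbp=flow.sbp.broadcast)
    of_sgd = flow.optim.SGD(linear.parameters(), lr=0.1)

    class ZeroGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.linear = linear
            self.add_optimizer(of_sgd)
            # Before: self.config.set_zero_redundancy_optimizer_mode("distributed_split")
            #         self.config.set_zero_redundancy_optimizer_min_size_after_split(1024)
            # After:
            self.config.enable_zero(True, stage=2, shard_min_size=1024, shard_restore_level=1)

        def build(self, x):
            loss = self.linear(x).sum()
            loss.backward()
            return loss

    # Run with 2 ranks, e.g. python3 -m oneflow.distributed.launch --nproc_per_node 2 ...
    graph = ZeroGraph()
    x = flow.rand(4, 8, placement=P, sbp=flow.sbp.split(0))
    graph(x)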
diff --git a/python/oneflow/test/graph/test_graph_zero.py b/python/oneflow/test/graph/test_graph_zero.py
index 20fc7366bab..51fa38a8657 100644
--- a/python/oneflow/test/graph/test_graph_zero.py
+++ b/python/oneflow/test/graph/test_graph_zero.py
@@ -26,40 +26,42 @@ def train_with_graph(iter_num=1):
         P = flow.placement("cuda", ranks=[0, 1])
         B = flow.sbp.broadcast
         S0 = flow.sbp.split(0)
-        linear = flow.nn.Linear(8, 4)
-        linear = linear.to_global(placement=P, sbp=B)
-        flow.nn.init.constant_(linear.weight, 2.068758)
-        flow.nn.init.constant_(linear.bias, 0.23)
-        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)
+
+        linear_dp = flow.nn.Linear(800, 400, bias=False)
+        linear_dp = linear_dp.to_global(placement=P, sbp=B)
+        flow.nn.init.constant_(linear_dp.weight, 2.068758)
+
+        linear_mp = flow.nn.Linear(400, 500, bias=False)
+        linear_mp = linear_mp.to_global(placement=P, sbp=S0)
+        flow.nn.init.constant_(linear_mp.weight, 2.068758)
+
+        of_sgd = flow.optim.SGD(
+            [{"params": linear_dp.parameters()}, {"params": linear_mp.parameters()}],
+            lr=0.001,
+            momentum=0.9,
+        )
         grad_scaler = flow.amp.StaticGradScaler(200)
 
-        x = flow.randint(1, 100, (4, 8), dtype=flow.float32, placement=P, sbp=S0)
+        x = flow.randint(1, 100, (6, 800), dtype=flow.float32, placement=P, sbp=S0)
 
         class LinearTrainGraphWithZeRO(flow.nn.Graph):
             def __init__(self):
                 super().__init__()
-                self.linear = linear
+                self.linear_dp = linear_dp
+                self.linear_mp = linear_mp
                 self.add_optimizer(of_sgd)
 
                 self.config.enable_amp(True)
                 self.set_grad_scaler(grad_scaler)
-                if zero_stage == 1:
-                    print("zero stage 1 optimization")
-                    self.config.set_zero_redundancy_optimizer_mode("distributed_split")
-                    self.config.set_zero_redundancy_optimizer_min_size_after_split(1)
-                if zero_stage == 2:
-                    self.config.set_zero_redundancy_optimizer_mode("distributed_split")
-                    self.config.set_zero_redundancy_optimizer_min_size_after_split(1)
-                    flow.boxing.nccl.enable_use_compute_stream(True)
-                if zero_stage == 3:
-                    print("zero stage 3 optimization")
-                    self.config.set_zero_redundancy_optimizer_mode("distributed_split")
-                    self.config.set_zero_redundancy_optimizer_min_size_after_split(1)
-                    flow.boxing.nccl.enable_use_compute_stream(True)
-                    flow.boxing.nccl.disable_group_boxing_by_dst_parallel(True)
+                self.config.enable_zero(
+                    True, stage=zero_stage, shard_min_size=1, shard_restore_level=0,
+                )
+                self.debug(2)
 
             def build(self, x):
-                out = self.linear(x)
+                out = self.linear_dp(x)
+                out = out.to_global(placement=P, sbp=B)
+                out = self.linear_mp(out)
                 loss = out.sum()
                 loss.backward()
                 return out
@@ -67,19 +69,26 @@ def build(self, x):
         class LinearEvalGraphWithZeRO(flow.nn.Graph):
             def __init__(self):
                 super().__init__()
-                self.linear = linear
+                self.linear_dp = linear_dp
+                self.linear_mp = linear_mp
 
                 self.config.enable_amp(True)
 
             def build(self, x):
-                out = self.linear(x)
+                out = self.linear_dp(x)
+                out = out.to_global(placement=P, sbp=B)
+                out = self.linear_mp(out)
                 return out
 
         linear_t_g = LinearTrainGraphWithZeRO()
+        linear_t_g.debug(1)
         linear_e_g = LinearEvalGraphWithZeRO()
+        linear_e_g.debug(1)
 
         def one_train_iter():
             out = linear_t_g(x)
+            if flow.env.get_rank() == 0:
+                print(linear_t_g)
 
         def one_eval_iter():
             out = linear_e_g(x)
@@ -89,8 +98,116 @@ def one_eval_iter():
 
         # After pass rewrite in training graph, parameters' sbp has been
         # changed from flow.sbp.broadcast to flow.sbp.split(0)
-        test_case.assertEqual(linear.weight.sbp[0], S0)
-        test_case.assertEqual(linear.bias.sbp[0], S0)
+        test_case.assertEqual(linear_dp.weight.sbp[0], S0)
+        test_case.assertEqual(linear_mp.weight.sbp[0], S0)
 
         # In evaluation graph, paramters's sbp are flow.sbp.split(0).
         # But their consumer will consum them as flow.sbp.broadcast.
         one_eval_iter()
 
     iter_num = 1
     graph_check_list = train_with_graph(iter_num)
 
 
+def _test_linear_train_graph_2d_with_zero(test_case, zero_stage=1):
+    def train_with_graph(iter_num=1):
+        P = flow.placement("cuda", ranks=[[0, 1], [2, 3]])
+        B = flow.sbp.broadcast
+        S0 = flow.sbp.split(0)
+        S1 = flow.sbp.split(1)
+
+        def get_mixed_linear():
+            linear_dp_mp = flow.nn.Linear(800, 400, bias=False)
+            linear_dp_mp = linear_dp_mp.to_global(placement=P, sbp=[B, S0])
+            flow.nn.init.constant_(linear_dp_mp.weight, 1.068758)
+
+            linear_mp_dp = flow.nn.Linear(800, 400, bias=False)
+            linear_mp_dp = linear_mp_dp.to_global(placement=P, sbp=[S0, B])
+            flow.nn.init.constant_(linear_mp_dp.weight, 1.068758)
+
+            class MixedLinear(flow.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.dp_mp = linear_dp_mp
+                    self.mp_dp = linear_mp_dp
+
+                def forward(self, x):
+                    x = self.dp_mp(x)
+                    x = flow.relu(x)
+                    x = self.mp_dp(x)
+                    x = flow.relu(x)
+                    return x
+
+            return MixedLinear()
+
+        mixed_linear0 = get_mixed_linear()
+        mixed_linear1 = get_mixed_linear()
+
+        of_sgd = flow.optim.SGD(
+            [
+                {"params": mixed_linear0.parameters()},
+                {"params": mixed_linear1.parameters()},
+            ],
+            lr=0.001,
+            momentum=0.9,
+        )
+        grad_scaler = flow.amp.StaticGradScaler(200)
+
+        x = flow.rand((2, 800), dtype=flow.float32, placement=P, sbp=[S0, B])
+
+        class LinearTrainGraph2DWithZeRO(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.mixed_linear0 = mixed_linear0
+                self.mixed_linear0.config.activation_checkpointing = True
+                self.mixed_linear1 = mixed_linear1
+                self.mixed_linear1.config.activation_checkpointing = True
+                self.add_optimizer(of_sgd)
+
+                self.config.enable_amp(True)
+                self.set_grad_scaler(grad_scaler)
+                self.config.enable_zero(
+                    True, stage=zero_stage, shard_min_size=1, shard_restore_level=1,
+                )
+
+            def build(self, x):
+                out = self.mixed_linear0(x)
+                out = self.mixed_linear1(out)
+                loss = out.mean()
+                loss.backward()
+                return loss
+
+        class LinearEvalGraph2DWithZeRO(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.mixed_linear0 = mixed_linear0
+                self.mixed_linear1 = mixed_linear1
+
+                self.config.enable_amp(True)
+
+            def build(self, x):
+                out = self.mixed_linear0(x)
+                out = self.mixed_linear1(out)
+                return out
+
+        linear_t_g = LinearTrainGraph2DWithZeRO()
+        linear_e_g = LinearEvalGraph2DWithZeRO()
+
+        def one_train_iter():
+            out = linear_t_g(x)
+            # if flow.env.get_rank() == 0:
+            #     print(linear_t_g)
+
+        def one_eval_iter():
+            out = linear_e_g(x)
+
+        for i in range(iter_num):
+            one_train_iter()
+
+        for state in linear_t_g._state():
+            test_case.assertEqual(
+                state.origin.sbp, (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=0))
+            )
 
+        # In the evaluation graph, parameters' sbp are flow.sbp.split(0).
+        # But their consumers will consume them as flow.sbp.broadcast.
@@ -113,5 +230,18 @@ def test_linear_train_graph_with_zero_3(test_case):
         _test_linear_train_graph_with_zero(test_case, 3)
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n4d()
+class TestLinearTrainGraph2DWithZeRO(oneflow.unittest.TestCase):
+    def test_linear_train_graph_2d_with_zero_3(test_case):
+        _test_linear_train_graph_2d_with_zero(test_case, 3)
+
+    def test_linear_train_graph_2d_with_zero_2(test_case):
+        _test_linear_train_graph_2d_with_zero(test_case, 2)
+
+    def test_linear_train_graph_2d_with_zero_1(test_case):
+        _test_linear_train_graph_2d_with_zero(test_case, 1)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/graph/test_optimization_conf.py b/python/oneflow/test/graph/test_optimization_conf.py
index da6348b7033..a60d339be8b 100644
--- a/python/oneflow/test/graph/test_optimization_conf.py
+++ b/python/oneflow/test/graph/test_optimization_conf.py
@@ -66,7 +66,7 @@ def __init__(self):
         self.config.allow_fuse_add_to_output(True)
         self.config.allow_fuse_cast_scale(True)
         self.config.set_gradient_accumulation_steps(100)
-        self.config.set_zero_redundancy_optimizer_mode("distributed_split")
+        self.config.enable_zero(True)
         self.config.enable_cudnn_conv_heuristic_search_algo(False)
 
     def build(self, x):
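To observe the pass taking effect, the same check the tests above perform can be applied to the migration sketch after the graph_config.py diff (reusing its `graph`, `linear`, and `x`; illustrative, not part of the patch):

.. code-block:: python

    graph(x)  # the first call compiles the job and runs the ZeRO rewrite
    # The trained parameter is now sharded across the placement:
    assert linear.weight.sbp[0] == flow.sbp.split(0)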