diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index ef81ebd27c8..e8220abbb73 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -97,7 +97,7 @@ const std::vector& NNGraph::outputs_tensor_meta_str() const { return outputs_tensor_meta_str_; } -int64_t NNGraph::variable_op_size() const { return variable_op_name2eager_blob_.size(); } +int64_t NNGraph::variable_op_size() const { return variable_op_names_.size(); } Maybe NNGraph::RegisterInputOpNamesAndTensors( const std::vector& input_op_names, @@ -142,21 +142,12 @@ Maybe NNGraph::RegisterVariableOpNamesAndTensors( const std::vector>& variable_tensors) { JUST(vm::CurrentRankSync()); CHECK_EQ_OR_RETURN(variable_op_names.size(), variable_tensors.size()); - CHECK_OR_RETURN(variable_op_name2eager_blob_.empty()); for (int32_t i = 0; i < variable_op_names.size(); ++i) { const std::shared_ptr& var = variable_tensors.at(i); CHECK_OR_RETURN(var->is_eager()); - Blob* var_blob = nullptr; - if (var->is_consistent()) { - // NOTE(chengcheng): var_blob maybe nullptr when consistent tensor has no cur_rank_phy_tensor. - const std::shared_ptr local_var = JUST(var->cur_rank_phy_tensor()); - var_blob = JUST(local_var->eager_blob_object())->mut_blob(); - } else { - var_blob = JUST(var->eager_blob_object())->mut_blob(); - } const std::string& var_name = variable_op_names.at(i); CHECK_OR_RETURN(!var_name.empty()); - CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second); + CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second); CHECK_OR_RETURN(variable_op_names_.insert(var_name).second); } return Maybe::Ok(); @@ -170,15 +161,8 @@ Maybe NNGraph::RegisterFreeEagerTensorsToVariableOpNames() { const std::string& var_name = pair.first; const std::shared_ptr& var = pair.second; CHECK_OR_RETURN(var->is_eager()); - Blob* var_blob = nullptr; - if (var->is_consistent()) { - const std::shared_ptr local_var = JUST(var->cur_rank_phy_tensor()); - var_blob = JUST(local_var->eager_blob_object())->mut_blob(); - } else { - var_blob = JUST(var->eager_blob_object())->mut_blob(); - } CHECK_OR_RETURN(!var_name.empty()); - CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second); + CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, var).second); CHECK_OR_RETURN(variable_op_names_.insert(var_name).second); } return Maybe::Ok(); @@ -218,10 +202,7 @@ Maybe NNGraph::CreateAndRegisterNewVariableOpInJobPass() { } std::shared_ptr tensor = JUST(one::functional::ConsistentConstant( blob_desc.shape(), value, Symbol(dtype), placement, *sbp_tuple)); - JUST(vm::CurrentRankSync()); - const std::shared_ptr local_var = JUST(tensor->cur_rank_phy_tensor()); - Blob* var_blob = JUST(local_var->eager_blob_object())->mut_blob(); - CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second); + CHECK_OR_RETURN(variable_op_name2tensor_.emplace(var_name, tensor).second); CHECK_OR_RETURN(variable_op_names_.insert(var_name).second); // NOTE(chengcheng): just for tensor lifetime hold by session context in graph lifetime valid. @@ -292,11 +273,56 @@ Maybe NNGraph::CompileAndInitRuntime() { PlanUtil::PopulateOpAttribute(&plan_, plan_.job_id2op_attribute_ref_table()); NewRuntimeBuffers(); + + JUST(DumpToVariable2BlobMap()); runtime_.reset(new Runtime(plan_, variable_op_name2eager_blob_)); runtime_inited_ = true; return Maybe::Ok(); } +Maybe NNGraph::DumpToVariable2BlobMap() { + CHECK_OR_RETURN(variable_op_name2eager_blob_.empty()); + JUST(vm::CurrentRankSync()); + for (const std::string& var_name : variable_op_names_) { + auto iter = variable_op_name2tensor_.find(var_name); + CHECK_OR_RETURN(iter != variable_op_name2tensor_.end()) << var_name << " not found."; + std::shared_ptr tensor = iter->second; + Blob* var_blob = nullptr; + // Check sbp whether same or not and allow auto_parallel to change sbp. + if (tensor->is_consistent()) { + cfg::NdSbpSignature nd_sbp_signature = cfg::NdSbpSignature( + job_.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(var_name)); + cfg::NdSbp job_nd_sbp = nd_sbp_signature.bn_in_op2nd_sbp().at("out"); + if (*JUST(tensor->nd_sbp()) == job_nd_sbp) { + const std::shared_ptr local_var = JUST(tensor->cur_rank_phy_tensor()); + var_blob = JUST(local_var->eager_blob_object())->mut_blob(); + } else { + CHECK_OR_RETURN(job_.job_conf().enable_auto_parallel()) + << "Only change sbp when enable auto_parallel."; + LOG(INFO) << "Variable with name `" << var_name << "` change sbp from " + << NdSbpParallelToString(*JUST(tensor->nd_sbp())) << " to " + << NdSbpParallelToString(job_nd_sbp); + std::vector> job_sbp_parallels; + for (int i = 0; i < job_nd_sbp.sbp_parallel_size(); ++i) { + job_sbp_parallels.emplace_back(job_nd_sbp.sbp_parallel(i)); + } + const auto& new_tensor = JUST(one::functional::ToConsistent( + tensor, JUST(tensor->parallel_desc()), job_sbp_parallels, {})); + JUST(vm::CurrentRankSync()); + // Use tensor.set_data inferface and make new TensorImpl instead of the old one. + variable_op_name2tensor_.at(var_name)->set_data(new_tensor); + const std::shared_ptr local_var = + JUST(new_tensor->cur_rank_phy_tensor()); + var_blob = JUST(local_var->eager_blob_object())->mut_blob(); + } + } else { + var_blob = JUST(tensor->eager_blob_object())->mut_blob(); + } + CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second); + } + return Maybe::Ok(); +} + void NNGraph::NewRuntimeBuffers() { auto* buffer_mgr = Global>>::Get(); // NOTE(chengcheng): diff --git a/oneflow/core/framework/nn_graph.h b/oneflow/core/framework/nn_graph.h index 6e3ac26ce3d..03509837a3d 100644 --- a/oneflow/core/framework/nn_graph.h +++ b/oneflow/core/framework/nn_graph.h @@ -57,6 +57,7 @@ class NNGraph final : public NNGraphIf { private: Maybe RegisterFreeEagerTensorsToVariableOpNames(); Maybe CreateAndRegisterNewVariableOpInJobPass(); + Maybe DumpToVariable2BlobMap(); void NewRuntimeBuffers(); void CloseRuntimeBuffers(); @@ -68,6 +69,7 @@ class NNGraph final : public NNGraphIf { std::vector output_tensors_valid_; std::vector inputs_tensor_meta_str_; std::vector outputs_tensor_meta_str_; + HashMap> variable_op_name2tensor_; HashMap variable_op_name2eager_blob_; HashSet variable_op_names_; Job job_;