diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index f6547053970..bead3d261b2 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -63,10 +63,57 @@ std::vector GetCompTaskNodesOnEdge( return comp_task_nodes; } +std::shared_ptr NewFakeDataRegstDesc() { + auto regst_desc = std::make_shared(); + regst_desc->mut_regst_desc_type()->mutable_data_regst_desc(); + return regst_desc; +} + } // namespace +void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) { + ConsumeRegst(regst_name, NewFakeDataRegstDesc()); + fake_consumed_regst_names_.insert(regst_name); +} + +void CompTaskNode::ConsumeFakeRegstsIf() { + ConsumeFakeRegsts(); + RegstDesc* data_regst_desc = nullptr; + for (const auto& pair : consumed_regsts()) { + for (const auto& regst_desc : pair.second) { + if (regst_desc->regst_desc_type().has_data_regst_desc()) { + // Only one fake data regst is created for each CompTaskNode with ConsumeFakeRegsts(). + CHECK(data_regst_desc == nullptr); + data_regst_desc = CHECK_NOTNULL(regst_desc.get()); + } else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) { + // do nothing. + } else { + UNIMPLEMENTED(); + } + } + } + if (data_regst_desc != nullptr) { + for (const auto& ibn : op_node()->op().input_bns()) { + // Only one fake data regst is created and just use it for all input_bns as a placeholder. 
+ data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn)); + } + } +} + +void CompTaskNode::EraseFakeRegstsIf() { + for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) { + EraseConsumedRegstsByName(fake_consumed_regst_name); + } + fake_consumed_regst_names_.clear(); +} + std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } +void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) { + TaskNode::InitFromProtoExceptConsumedRegsts(proto); + parallel_ctx_ = proto.parallel_ctx(); +} + void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const { TaskNode::ToProto(task_proto, check); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index bba6d957668..7d6cb03bc47 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -18,17 +18,25 @@ limitations under the License. #include "oneflow/core/graph/task_node.h" #include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/graph/fake_consumed_regst_provider.h" #include "oneflow/core/device/cuda_util.h" namespace oneflow { -class CompTaskNode : public TaskNode { +class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider { public: OF_DISALLOW_COPY_AND_MOVE(CompTaskNode); CompTaskNode() = default; virtual ~CompTaskNode() = default; virtual void ToProto(TaskProto*, bool check) const override; + virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override; + void ConsumeFakeRegstsIf() override; + void EraseFakeRegstsIf() override; + + // ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks. 
+ virtual void ConsumeFakeRegsts() = 0; + void ConsumeFakeRegst(const std::string& regst_name); // parallel_ctx_ int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } @@ -58,6 +66,7 @@ class CompTaskNode : public TaskNode { private: ParallelContext parallel_ctx_; const OpNode* op_node_; + HashSet fake_consumed_regst_names_; }; class OpCompTaskNodeCreator { diff --git a/oneflow/core/graph/fake_consumed_regst_provider.h b/oneflow/core/graph/fake_consumed_regst_provider.h new file mode 100644 index 00000000000..2de5a194924 --- /dev/null +++ b/oneflow/core/graph/fake_consumed_regst_provider.h @@ -0,0 +1,35 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ +#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ + +namespace oneflow { + +// Provide a compute task node with a fake input regst, and its output regst can be inferred using +// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob +// desc, mainly to ensure that the transport task node has the correct input blob desc. 
+class FakeConsumedRegstProvider { + public: + FakeConsumedRegstProvider() = default; + virtual ~FakeConsumedRegstProvider() = default; + + virtual void ConsumeFakeRegstsIf() = 0; + virtual void EraseFakeRegstsIf() = 0; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ diff --git a/oneflow/core/graph/normal_forward_compute_task_node.h b/oneflow/core/graph/normal_forward_compute_task_node.h index 53d1807c5a7..b3da9f36edb 100644 --- a/oneflow/core/graph/normal_forward_compute_task_node.h +++ b/oneflow/core/graph/normal_forward_compute_task_node.h @@ -30,6 +30,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kNormalForward; } diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index ed25fa43cc4..affa2818481 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/graph/task_node.h" +#include "oneflow/core/common/container_util.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/memory/memory_case_util.h" #include "oneflow/core/graph/task_graph_rebuild_ctx.h" @@ -203,6 +204,38 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } +void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) { + // Step1: init some scalar items. 
+ CHECK(task_proto.task_type() == GetTaskType()); + machine_id_ = task_proto.machine_id(); + thrd_id_ = task_proto.thrd_id(); + task_id_ = task_proto.task_id(); + new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_))); + CHECK(task_proto.job_id() == GlobalJobDesc().job_id()); + chain_id_ = task_proto.chain_id(); + order_in_chain_ = task_proto.order_in_chain(); + // Step2: check exec_gph empty. + CHECK(task_proto.exec_sequence().exec_node().empty()); + // Step3: init produced_regst. + for (const auto& pair : task_proto.produced_regst_desc()) { + const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem()); + // regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto. + regst_desc->InitFromProtoExceptConsumers(pair.second); + } +} + +Maybe TaskNode::InitConsumedRegstsFromProto( + const TaskProto& task_proto, + const std::function(int64_t regst_desc_id)>& RegstDesc4Id) { + // init consumed_regst. + for (const auto& pair : task_proto.consumed_regst_desc_id()) { + for (int64_t regst_desc_id : pair.second.regst_desc_id()) { + ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id))); + } + } + return Maybe::Ok(); +} + void TaskNode::ToProto(TaskProto* task_proto, bool check) const { // Step1: process some scalar items. 
task_proto->set_task_type(GetTaskType()); @@ -265,9 +298,23 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) { } void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) { + if (edge->HasRegst(name)) { return; } edge->AddRegst(name, GetProducedRegst(name)); } +std::shared_ptr TaskNode::GetAndCheckRegst(const std::string& name, + bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const { + auto iter = produced_regsts_.find(name); + if (iter == produced_regsts_.end()) { return nullptr; } + const auto& regst = (iter->second); + CHECK_EQ(regst->min_register_num(), min_register_num); + CHECK_EQ(regst->max_register_num(), max_register_num); + CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem); + return regst; +} + std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) { return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum); } @@ -275,6 +322,10 @@ std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool std::shared_ptr TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem, int32_t min_register_num, int32_t max_register_num) { + // Because the Regst of separate compilation is not created in order, some Regst may have been + // built. This implementation can avoid ProduceRegst being called multiple times. 
+ const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); + if (regst) { return regst; } RegstDescTypeProto regst_desc_type; regst_desc_type.mutable_data_regst_desc(); return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type); @@ -353,8 +404,14 @@ std::shared_ptr TaskEdge::GetRegst(const std::string& name_in_produce return name_in_producer2regst_.at(name_in_producer); } +bool TaskEdge::HasRegst(const std::string& name_in_producer) const { + return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end()); +} + std::shared_ptr TaskEdge::GetSoleRegst() const { - CHECK_EQ(name_in_producer2regst_.size(), 1); + CHECK_EQ(name_in_producer2regst_.size(), 1) + << "edge: " << this << ", src: " << src_node()->task_id() + << ", dst: " << dst_node()->task_id(); return name_in_producer2regst_.begin()->second; } @@ -367,6 +424,11 @@ std::vector> TaskEdge::GetRegsts() const { void TaskEdge::AddRegst(const std::string& name_in_producer, const std::shared_ptr& regst) { + if (HasRegst(name_in_producer)) { + CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id() + == regst->regst_desc_id()); + return; + } CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 7856c0d08d1..81937610ae1 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -100,6 +100,11 @@ class TaskNode : public Node { std::string VisualStr() const override; virtual bool IsMeaningLess(); void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } + // Used to create task node from proto in plan separation compilation. 
+ virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto); + Maybe InitConsumedRegstsFromProto( + const TaskProto& task_proto, + const std::function(int64_t regst_desc_id)>& RegstDesc4Id); virtual void ToProto(TaskProto* task_proto, bool check) const; void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual MemZoneId MemZoneId121() const; @@ -150,6 +155,9 @@ class TaskNode : public Node { private: void UpdateTaskId(); + std::shared_ptr GetAndCheckRegst(const std::string& name, bool enable_reuse_mem, + int32_t min_register_num, + int32_t max_register_num) const; int64_t machine_id_; int64_t thrd_id_; @@ -172,6 +180,7 @@ class TaskEdge final : public Edge { ~TaskEdge() override = default; std::shared_ptr GetRegst(const std::string& name_in_producer) const; + bool HasRegst(const std::string& name_in_producer) const; std::shared_ptr GetSoleRegst() const; std::vector> GetRegsts() const; const HashSet& GetLbis() const { return lbis_; } @@ -181,6 +190,7 @@ class TaskEdge final : public Edge { void AddLbis(const std::vector& lbis) { lbis_.insert(lbis.begin(), lbis.end()); } void CheckRegstLbiValid() const; + bool HasRegst() const { return !name_in_producer2regst_.empty(); } Maybe InitFromProto(const TaskEdgeProto& proto, const TaskGraphRebuildCtx& task_graph_rebuild_ctx); diff --git a/oneflow/core/graph_impl/acc_compute_task_node.cpp b/oneflow/core/graph_impl/acc_compute_task_node.cpp index 17695dc3f68..016bec8cffb 100644 --- a/oneflow/core/graph_impl/acc_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_compute_task_node.cpp @@ -27,6 +27,7 @@ class AccCompTaskNode final : public CompTaskNode { void BuildExecGphAndRegst() override; void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; }; void AccCompTaskNode::ProduceAllRegstsAndBindEdges() { @@ -35,6 +36,7 @@ void AccCompTaskNode::ProduceAllRegstsAndBindEdges() { } void AccCompTaskNode::ConsumeAllRegsts() { 
ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void AccCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); diff --git a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp index 7de08f51154..b6dee6102bb 100644 --- a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp @@ -27,6 +27,7 @@ class AccCtrlTickCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; void BuildExecGphAndRegst() override; + void ConsumeFakeRegsts() override; }; void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() { @@ -38,6 +39,8 @@ void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void AccCtrlTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); diff --git a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp index 47f93a81c2c..83b23730c49 100644 --- a/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/acc_tick_compute_task_node.cpp @@ -26,6 +26,7 @@ class AccTickCompTaskNode final : public CompTaskNode { TaskType GetTaskType() const override { return TaskType::kAccTick; } void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -37,6 +38,7 @@ void AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() { void AccTickCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void 
AccTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } void AccTickCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr in_regst = GetSoleConsumedRegst("in"); diff --git a/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp b/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp index d4326fc9588..80baabc6321 100644 --- a/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp +++ b/oneflow/core/graph_impl/callback_notify_compute_task_node.cpp @@ -29,6 +29,7 @@ class CallbackNotifyCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -38,6 +39,8 @@ void CallbackNotifyCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void CallbackNotifyCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CallbackNotifyCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = this->op(); diff --git a/oneflow/core/graph_impl/case_compute_task_node.cpp b/oneflow/core/graph_impl/case_compute_task_node.cpp index 803b5271cc7..fc5cc90b52b 100644 --- a/oneflow/core/graph_impl/case_compute_task_node.cpp +++ b/oneflow/core/graph_impl/case_compute_task_node.cpp @@ -26,6 +26,7 @@ class CaseCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kCase; } @@ -36,6 +37,8 @@ class CaseCompTaskNode final : public CompTaskNode { void CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void CaseCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() { HashMap lbi2obn_id; FOR_RANGE(int64_t, obn_id, 0, 
op()->output_bns().size()) { diff --git a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp index d2a4a6775b6..65341724df9 100644 --- a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp +++ b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp @@ -30,6 +30,7 @@ class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp index 5d7f94da6d2..c548272c4cb 100644 --- a/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp +++ b/oneflow/core/graph_impl/decode_h2d_compute_task_node.cpp @@ -27,6 +27,7 @@ class DecodeH2DCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDecodeH2D; } @@ -38,6 +39,8 @@ void DecodeH2DCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void DecodeH2DCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() { auto regst_num = ParseIntegerFromEnv("ONEFLOW_DECODE_H2D_REGST_NUM", 2); std::shared_ptr out_regst = ProduceRegst("out", false, regst_num, regst_num); diff --git 
a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp index 1a2a53c0b50..1ffcfdcca95 100644 --- a/oneflow/core/graph_impl/device_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/device_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class DeviceTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void DeviceTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DeviceTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DeviceTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp index bc8a9ace7f3..5c3b5e9b069 100644 --- a/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_concat_compute_task_node.cpp @@ -26,6 +26,7 @@ class DistributeConcatCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeConcat; } @@ -49,6 +50,8 @@ void DistributeConcatCompTaskNode::ConsumeAllRegsts() { CHECK_EQ(cnt, 1); } +void DistributeConcatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DistributeConcatCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp index fd3e0334888..f387719a17f 100644 --- 
a/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp +++ b/oneflow/core/graph_impl/distribute_split_compute_task_node.cpp @@ -26,6 +26,7 @@ class DistributeSplitCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kDistributeSplit; } @@ -44,6 +45,8 @@ void DistributeSplitCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DistributeSplitCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DistributeSplitCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp index 901a4b10092..e587ced5e02 100644 --- a/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/dst_subset_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class DstSubsetTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void DstSubsetTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void DstSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void DstSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/esac_compute_task_node.cpp b/oneflow/core/graph_impl/esac_compute_task_node.cpp index 8595c22ce0f..ac43f5a2c91 100644 --- a/oneflow/core/graph_impl/esac_compute_task_node.cpp +++ b/oneflow/core/graph_impl/esac_compute_task_node.cpp @@ -26,6 +26,7 @@ class 
EsacCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override { UNIMPLEMENTED() << "EsacCompTaskNode is deprecated"; } TaskType GetTaskType() const override { return TaskType::kEsac; } diff --git a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp index 8f71a3ac335..a93742b76a2 100644 --- a/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp +++ b/oneflow/core/graph_impl/normal_forward_compute_task_node.cpp @@ -95,6 +95,8 @@ void NormalForwardCompTaskNode::ConsumeAllRegsts() { }); } +void NormalForwardCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void NormalForwardCompTaskNode::BuildExecGphAndRegst() { BuildExecGphStructAndBindInRegst(); BuildOutRegst(); diff --git a/oneflow/core/graph_impl/pack_compute_task_node.cpp b/oneflow/core/graph_impl/pack_compute_task_node.cpp index 88ec0ce6834..88f27c6527b 100644 --- a/oneflow/core/graph_impl/pack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/pack_compute_task_node.cpp @@ -28,6 +28,7 @@ class PackCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; @@ -40,6 +41,8 @@ void PackCompTaskNode::ProduceAllRegstsAndBindEdges() { void PackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void PackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void PackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp index 5045b65dc23..8b0cfef852b 100644 --- a/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp +++ 
b/oneflow/core/graph_impl/reentrant_lock_compute_task_node.cpp @@ -30,6 +30,7 @@ class ReentrantLockCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; void InferProducedDataRegstTimeShape() override; }; @@ -44,6 +45,8 @@ void ReentrantLockCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void ReentrantLockCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void ReentrantLockCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/repeat_compute_task_node.cpp b/oneflow/core/graph_impl/repeat_compute_task_node.cpp index 6385ac8dc37..582d61dadef 100644 --- a/oneflow/core/graph_impl/repeat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/repeat_compute_task_node.cpp @@ -26,6 +26,7 @@ class RepeatCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; TaskType GetTaskType() const override { return TaskType::kRepeat; } @@ -37,6 +38,8 @@ void RepeatCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void RepeatCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr out_regst = ProduceRegst("out", false, 1, 1); ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); }); diff --git a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp index 93393a89145..93cbb1dda7d 100644 --- a/oneflow/core/graph_impl/source_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/source_tick_compute_task_node.cpp @@ -26,6 +26,7 @@ class SourceTickCompTaskNode 
final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} + void ConsumeFakeRegsts() override {} void BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } diff --git a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp index e1fa7349553..30639d07abe 100644 --- a/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp +++ b/oneflow/core/graph_impl/src_subset_tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class SrcSubsetTickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void SrcSubsetTickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void SrcSubsetTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void SrcSubsetTickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp index aa6d46729b0..4c8e2743ab8 100644 --- a/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp +++ b/oneflow/core/graph_impl/ssp_variable_proxy_task_node.cpp @@ -67,6 +67,8 @@ class SspVariableProxyCompTaskNode final : public CompTaskNode { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("var", edge->GetSoleRegst()); }); } + void ConsumeFakeRegsts() override { ConsumeFakeRegst("var"); } + TaskType GetTaskType() const override { return TaskType::kSspVariableProxy; } private: diff --git a/oneflow/core/graph_impl/tick_compute_task_node.cpp b/oneflow/core/graph_impl/tick_compute_task_node.cpp index 052b72841a7..306f78874e6 100644 --- a/oneflow/core/graph_impl/tick_compute_task_node.cpp +++ 
b/oneflow/core/graph_impl/tick_compute_task_node.cpp @@ -30,6 +30,7 @@ class TickCompTaskNode final : public CompTaskNode { private: void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; void BuildExecGphAndRegst() override; }; @@ -43,6 +44,8 @@ void TickCompTaskNode::ConsumeAllRegsts() { ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); } +void TickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void TickCompTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/unpack_compute_task_node.cpp b/oneflow/core/graph_impl/unpack_compute_task_node.cpp index 2916ee2fbb5..49a4dfa3e2d 100644 --- a/oneflow/core/graph_impl/unpack_compute_task_node.cpp +++ b/oneflow/core/graph_impl/unpack_compute_task_node.cpp @@ -28,6 +28,7 @@ class UnpackCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override; + void ConsumeFakeRegsts() override; private: void BuildExecGphAndRegst() override; @@ -42,6 +43,8 @@ void UnpackCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); } +void UnpackCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); } + void UnpackCompTaskNode::BuildExecGphAndRegst() { ExecNode* exec_node = mut_exec_gph().NewNode(); exec_node->mut_op() = op(); diff --git a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp index ba79fca0c8d..9d3e0626b9e 100644 --- a/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp +++ b/oneflow/core/graph_impl/wait_and_send_ids_compute_task_node.cpp @@ -26,6 +26,7 @@ class WaitAndSendIdsCompTaskNode final : public CompTaskNode { void ProduceAllRegstsAndBindEdges() override; void ConsumeAllRegsts() override {} + void ConsumeFakeRegsts() override {} void 
BuildExecGphAndRegst() override; bool IsMeaningLess() override { return false; } diff --git a/oneflow/core/register/register_desc.cpp b/oneflow/core/register/register_desc.cpp index ca5ac440556..d7b3e75efb1 100644 --- a/oneflow/core/register/register_desc.cpp +++ b/oneflow/core/register/register_desc.cpp @@ -120,6 +120,32 @@ void RegstDesc::EraseUninitializedShapeBlob() { }); } +void RegstDesc::InitFromProtoExceptConsumers(const RegstDescProto& proto) { + regst_desc_id_ = proto.regst_desc_id(); + CHECK_EQ(proto.producer_task_id(), producer_->task_id()); + regst_desc_type_ = proto.regst_desc_type(); + if (regst_desc_type_.has_data_regst_desc()) { + const DataRegstDesc& data_regst_desc_proto = proto.regst_desc_type().data_regst_desc(); + for (const auto& pair : data_regst_desc_proto.lbi2blob_desc()) { + *AddLbi(pair.lbi()) = BlobDesc(pair.blob_desc()); + } + CHECK(!data_regst_desc_proto.has_time_shape()); + } else if (regst_desc_type_.has_ctrl_regst_desc()) { + // do nothing + } else { + UNIMPLEMENTED(); + } + min_register_num_ = proto.min_register_num(); + max_register_num_ = proto.max_register_num(); + min_register_num_ = proto.register_num(); + mem_case_ = proto.mem_case(); + enable_reuse_mem_ = proto.enable_reuse_mem(); + mem_block_id_ = proto.mem_block_id(); + mem_block_offset_ = proto.mem_block_offset(); + hint_inplace_consumed_regst_desc_id_ = proto.hint_inplace_consumed_regst_desc_id(); + force_inplace_consumed_regst_desc_id_ = proto.force_inplace_consumed_regst_desc_id(); +} + void RegstDesc::ToProto(RegstDescProto* ret, bool check) const { ret->set_regst_desc_id(regst_desc_id_); ret->set_producer_task_id(producer_->task_id()); diff --git a/oneflow/core/register/register_desc.h b/oneflow/core/register/register_desc.h index faee94ca856..1da2e267325 100644 --- a/oneflow/core/register/register_desc.h +++ b/oneflow/core/register/register_desc.h @@ -99,6 +99,7 @@ class RegstDesc final { // util void EraseUninitializedShapeBlob(); + void 
InitFromProtoExceptConsumers(const RegstDescProto& proto); void ToProto(RegstDescProto* proto) const { ToProto(proto, /*check*/ true); } void ToProto(RegstDescProto*, bool check) const; bool HasSameBlobDescs(const RegstDesc*); @@ -117,8 +118,8 @@ class RegstDesc final { bool enable_reuse_mem_; int32_t mem_block_id_; int64_t mem_block_offset_; - int32_t hint_inplace_consumed_regst_desc_id_; - int32_t force_inplace_consumed_regst_desc_id_; + int64_t hint_inplace_consumed_regst_desc_id_; + int64_t force_inplace_consumed_regst_desc_id_; std::shared_ptr data_regst_time_shape_; };