-
Notifications
You must be signed in to change notification settings - Fork 796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add consume fake regst #10140
Add consume fake regst #10140
Changes from all commits
62dee17
db84632
49c8d18
041a4b3
3e08944
e0cf92b
008239e
be2987d
9acbc79
34b3133
acff92c
17203d0
4377368
88f4297
0e7b8ed
266c388
9f952bd
a9ad100
95ae89f
38bf366
a3b9798
2be6d00
014a96c
277a8d9
6d7ed9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,10 +63,57 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge( | |
return comp_task_nodes; | ||
} | ||
|
||
std::shared_ptr<RegstDesc> NewFakeDataRegstDesc() { | ||
auto regst_desc = std::make_shared<RegstDesc>(); | ||
regst_desc->mut_regst_desc_type()->mutable_data_regst_desc(); | ||
return regst_desc; | ||
} | ||
|
||
} // namespace | ||
|
||
void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) { | ||
ConsumeRegst(regst_name, NewFakeDataRegstDesc()); | ||
fake_consumed_regst_names_.insert(regst_name); | ||
} | ||
|
||
void CompTaskNode::ConsumeFakeRegstsIf() { | ||
ConsumeFakeRegsts(); | ||
RegstDesc* data_regst_desc = nullptr; | ||
for (const auto& pair : consumed_regsts()) { | ||
for (const auto& regst_desc : pair.second) { | ||
if (regst_desc->regst_desc_type().has_data_regst_desc()) { | ||
// Only one fake data regst is creatd for each CompTaskNode with ConsumeFakeRegsts(). | ||
CHECK(data_regst_desc == nullptr); | ||
data_regst_desc = CHECK_NOTNULL(regst_desc.get()); | ||
} else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) { | ||
// do nothing. | ||
} else { | ||
UNIMPLEMENTED(); | ||
} | ||
} | ||
} | ||
if (data_regst_desc != nullptr) { | ||
for (const auto& ibn : op_node()->op().input_bns()) { | ||
// Only one fake data regst is creatd and just use it for all input_bns as a placeholder. | ||
data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 多个 lbi 指向了这占位的 fake regst |
||
} | ||
} | ||
} | ||
|
||
void CompTaskNode::EraseFakeRegstsIf() { | ||
for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) { | ||
EraseConsumedRegstsByName(fake_consumed_regst_name); | ||
} | ||
fake_consumed_regst_names_.clear(); | ||
} | ||
|
||
std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } | ||
|
||
void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) { | ||
TaskNode::InitFromProtoExceptConsumedRegsts(proto); | ||
parallel_ctx_ = proto.parallel_ctx(); | ||
} | ||
|
||
void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const { | ||
TaskNode::ToProto(task_proto, check); | ||
*(task_proto->mutable_parallel_ctx()) = parallel_ctx_; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,17 +18,25 @@ limitations under the License. | |
|
||
#include "oneflow/core/graph/task_node.h" | ||
#include "oneflow/core/graph/op_graph.h" | ||
#include "oneflow/core/graph/fake_consumed_regst_provider.h" | ||
#include "oneflow/core/device/cuda_util.h" | ||
|
||
namespace oneflow { | ||
|
||
class CompTaskNode : public TaskNode { | ||
class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider { | ||
public: | ||
OF_DISALLOW_COPY_AND_MOVE(CompTaskNode); | ||
CompTaskNode() = default; | ||
virtual ~CompTaskNode() = default; | ||
|
||
virtual void ToProto(TaskProto*, bool check) const override; | ||
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override; | ||
void ConsumeFakeRegstsIf() override; | ||
void EraseFakeRegstsIf() override; | ||
|
||
// ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks. | ||
virtual void ConsumeFakeRegsts() = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BoxingTaskGraph 中的 TransportTaskNode 会带有上游的多个 rank 的 ComputeTaskNode,当做 task graph 的 rank 内 infer BlobDesc 时,非本 rank 的 input regst 是不存在的,但是需要一个 fake 数据来保证正常的推理逻辑可以通过。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 本质上这个是一个取巧的临时绕过去的办法。 未来要重构出 task graph build 不依赖 input regst 的版本 |
||
void ConsumeFakeRegst(const std::string& regst_name); | ||
|
||
// parallel_ctx_ | ||
int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } | ||
|
@@ -58,6 +66,7 @@ class CompTaskNode : public TaskNode { | |
private: | ||
ParallelContext parallel_ctx_; | ||
const OpNode* op_node_; | ||
HashSet<std::string> fake_consumed_regst_names_; | ||
}; | ||
|
||
class OpCompTaskNodeCreator { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ | ||
#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ | ||
|
||
namespace oneflow { | ||
|
||
// Provide a compute task node with a fake input regst, and its output regst can be inferred using | ||
// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob | ||
// desc, mainly to ensure that the transport task node has the correct input blob desc. | ||
class FakeConsumedRegstProvider { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mixin 接口类 |
||
public: | ||
FakeConsumedRegstProvider() = default; | ||
virtual ~FakeConsumedRegstProvider() = default; | ||
|
||
virtual void ConsumeFakeRegstsIf() = 0; | ||
virtual void EraseFakeRegstsIf() = 0; | ||
}; | ||
|
||
} // namespace oneflow | ||
|
||
#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and | |
limitations under the License. | ||
*/ | ||
#include "oneflow/core/graph/task_node.h" | ||
#include "oneflow/core/common/container_util.h" | ||
#include "oneflow/core/job/id_manager.h" | ||
#include "oneflow/core/memory/memory_case_util.h" | ||
#include "oneflow/core/graph/task_graph_rebuild_ctx.h" | ||
|
@@ -203,6 +204,38 @@ std::string TaskNode::VisualStr() const { | |
|
||
bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } | ||
|
||
void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) { | ||
// Step1: init some scalar items. | ||
CHECK(task_proto.task_type() == GetTaskType()); | ||
machine_id_ = task_proto.machine_id(); | ||
thrd_id_ = task_proto.thrd_id(); | ||
task_id_ = task_proto.task_id(); | ||
new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_))); | ||
CHECK(task_proto.job_id() == GlobalJobDesc().job_id()); | ||
chain_id_ = task_proto.chain_id(); | ||
order_in_chain_ = task_proto.order_in_chain(); | ||
// Step2: check exec_gph empty. | ||
CHECK(task_proto.exec_sequence().exec_node().empty()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consumed_regst_desc_id 要检查吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
检查它非空么,consumed_regst_desc_id 在后面做了消费 |
||
// Step3: init produced_regst. | ||
for (const auto& pair : task_proto.produced_regst_desc()) { | ||
const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem()); | ||
// regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto. | ||
regst_desc->InitFromProtoExceptConsumers(pair.second); | ||
} | ||
} | ||
|
||
Maybe<void> TaskNode::InitConsumedRegstsFromProto( | ||
const TaskProto& task_proto, | ||
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) { | ||
// init consumed_regst. | ||
for (const auto& pair : task_proto.consumed_regst_desc_id()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consumed_regst_desc_id 这里做了消费 |
||
for (int64_t regst_desc_id : pair.second.regst_desc_id()) { | ||
ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id))); | ||
} | ||
} | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
void TaskNode::ToProto(TaskProto* task_proto, bool check) const { | ||
// Step1: process some scalar items. | ||
task_proto->set_task_type(GetTaskType()); | ||
|
@@ -265,16 +298,34 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) { | |
} | ||
|
||
void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) { | ||
if (edge->HasRegst(name)) { return; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 避免重复绑定 |
||
edge->AddRegst(name, GetProducedRegst(name)); | ||
} | ||
|
||
std::shared_ptr<RegstDesc> TaskNode::GetAndCheckRegst(const std::string& name, | ||
bool enable_reuse_mem, | ||
int32_t min_register_num, | ||
int32_t max_register_num) const { | ||
auto iter = produced_regsts_.find(name); | ||
if (iter == produced_regsts_.end()) { return nullptr; } | ||
const auto& regst = (iter->second); | ||
CHECK_EQ(regst->min_register_num(), min_register_num); | ||
CHECK_EQ(regst->max_register_num(), max_register_num); | ||
CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem); | ||
return regst; | ||
} | ||
|
||
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) { | ||
return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum); | ||
} | ||
|
||
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem, | ||
int32_t min_register_num, | ||
int32_t max_register_num) { | ||
// Because the Regst of separate compilation is not created in order, some Regst may have been | ||
// built. This implementation can avoid ProduceRegst being called multiple times. | ||
const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num); | ||
if (regst) { return regst; } | ||
RegstDescTypeProto regst_desc_type; | ||
regst_desc_type.mutable_data_regst_desc(); | ||
return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type); | ||
|
@@ -353,8 +404,14 @@ std::shared_ptr<RegstDesc> TaskEdge::GetRegst(const std::string& name_in_produce | |
return name_in_producer2regst_.at(name_in_producer); | ||
} | ||
|
||
bool TaskEdge::HasRegst(const std::string& name_in_producer) const { | ||
return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end()); | ||
} | ||
|
||
std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const { | ||
CHECK_EQ(name_in_producer2regst_.size(), 1); | ||
CHECK_EQ(name_in_producer2regst_.size(), 1) | ||
<< "edge: " << this << ", src: " << src_node()->task_id() | ||
<< ", dst: " << dst_node()->task_id(); | ||
return name_in_producer2regst_.begin()->second; | ||
} | ||
|
||
|
@@ -367,6 +424,11 @@ std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const { | |
|
||
void TaskEdge::AddRegst(const std::string& name_in_producer, | ||
const std::shared_ptr<RegstDesc>& regst) { | ||
if (HasRegst(name_in_producer)) { | ||
CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id() | ||
== regst->regst_desc_id()); | ||
return; | ||
} | ||
CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,6 +100,11 @@ class TaskNode : public Node<TaskNode, TaskEdge> { | |
std::string VisualStr() const override; | ||
virtual bool IsMeaningLess(); | ||
void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); } | ||
// Used to create task node from proto in plan separation compilation. | ||
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里可以注释一下是为了分离编译中途基于 plan 创建 task node 用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
done |
||
Maybe<void> InitConsumedRegstsFromProto( | ||
const TaskProto& task_proto, | ||
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id); | ||
virtual void ToProto(TaskProto* task_proto, bool check) const; | ||
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); | ||
virtual MemZoneId MemZoneId121() const; | ||
|
@@ -150,6 +155,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> { | |
|
||
private: | ||
void UpdateTaskId(); | ||
std::shared_ptr<RegstDesc> GetAndCheckRegst(const std::string& name, bool enable_reuse_mem, | ||
int32_t min_register_num, | ||
int32_t max_register_num) const; | ||
|
||
int64_t machine_id_; | ||
int64_t thrd_id_; | ||
|
@@ -172,6 +180,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> { | |
~TaskEdge() override = default; | ||
|
||
std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const; | ||
bool HasRegst(const std::string& name_in_producer) const; | ||
std::shared_ptr<RegstDesc> GetSoleRegst() const; | ||
std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const; | ||
const HashSet<LogicalBlobId>& GetLbis() const { return lbis_; } | ||
|
@@ -181,6 +190,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> { | |
void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); } | ||
|
||
void CheckRegstLbiValid() const; | ||
bool HasRegst() const { return !name_in_producer2regst_.empty(); } | ||
|
||
Maybe<void> InitFromProto(const TaskEdgeProto& proto, | ||
const TaskGraphRebuildCtx& task_graph_rebuild_ctx); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里没看懂,表示最多只有一个 data regst ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,只创建了一个占位的 fake regst。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那加下注释吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done