Skip to content

Commit

Permalink
Add consume fake regst (#10140)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
strint and mergify[bot] authored May 6, 2023
1 parent 726904e commit 9cd5d64
Show file tree
Hide file tree
Showing 30 changed files with 251 additions and 4 deletions.
47 changes: 47 additions & 0 deletions oneflow/core/graph/compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,57 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge(
return comp_task_nodes;
}

std::shared_ptr<RegstDesc> NewFakeDataRegstDesc() {
auto regst_desc = std::make_shared<RegstDesc>();
regst_desc->mut_regst_desc_type()->mutable_data_regst_desc();
return regst_desc;
}

} // namespace

void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) {
ConsumeRegst(regst_name, NewFakeDataRegstDesc());
fake_consumed_regst_names_.insert(regst_name);
}

void CompTaskNode::ConsumeFakeRegstsIf() {
ConsumeFakeRegsts();
RegstDesc* data_regst_desc = nullptr;
for (const auto& pair : consumed_regsts()) {
for (const auto& regst_desc : pair.second) {
if (regst_desc->regst_desc_type().has_data_regst_desc()) {
// Only one fake data regst is creatd for each CompTaskNode with ConsumeFakeRegsts().
CHECK(data_regst_desc == nullptr);
data_regst_desc = CHECK_NOTNULL(regst_desc.get());
} else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) {
// do nothing.
} else {
UNIMPLEMENTED();
}
}
}
if (data_regst_desc != nullptr) {
for (const auto& ibn : op_node()->op().input_bns()) {
// Only one fake data regst is creatd and just use it for all input_bns as a placeholder.
data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn));
}
}
}

void CompTaskNode::EraseFakeRegstsIf() {
for (const auto& fake_consumed_regst_name : fake_consumed_regst_names_) {
EraseConsumedRegstsByName(fake_consumed_regst_name);
}
fake_consumed_regst_names_.clear();
}

std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); }

void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) {
TaskNode::InitFromProtoExceptConsumedRegsts(proto);
parallel_ctx_ = proto.parallel_ctx();
}

void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const {
TaskNode::ToProto(task_proto, check);
*(task_proto->mutable_parallel_ctx()) = parallel_ctx_;
Expand Down
11 changes: 10 additions & 1 deletion oneflow/core/graph/compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@ limitations under the License.

#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/fake_consumed_regst_provider.h"
#include "oneflow/core/device/cuda_util.h"

namespace oneflow {

class CompTaskNode : public TaskNode {
class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {
public:
OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);
CompTaskNode() = default;
virtual ~CompTaskNode() = default;

virtual void ToProto(TaskProto*, bool check) const override;
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override;
void ConsumeFakeRegstsIf() override;
void EraseFakeRegstsIf() override;

// ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.
virtual void ConsumeFakeRegsts() = 0;
void ConsumeFakeRegst(const std::string& regst_name);

// parallel_ctx_
int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
Expand Down Expand Up @@ -58,6 +66,7 @@ class CompTaskNode : public TaskNode {
private:
ParallelContext parallel_ctx_;
const OpNode* op_node_;
HashSet<std::string> fake_consumed_regst_names_;
};

class OpCompTaskNodeCreator {
Expand Down
35 changes: 35 additions & 0 deletions oneflow/core/graph/fake_consumed_regst_provider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_
#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_

namespace oneflow {

// Provide a compute task node with a fake input regst, and its output regst can be inferred using
// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob
// desc, mainly to ensure that the transport task node has the correct input blob desc.
class FakeConsumedRegstProvider {
public:
FakeConsumedRegstProvider() = default;
virtual ~FakeConsumedRegstProvider() = default;

virtual void ConsumeFakeRegstsIf() = 0;
virtual void EraseFakeRegstsIf() = 0;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_
1 change: 1 addition & 0 deletions oneflow/core/graph/normal_forward_compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode {

void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;

TaskType GetTaskType() const override { return TaskType::kNormalForward; }

Expand Down
64 changes: 63 additions & 1 deletion oneflow/core/graph/task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/memory/memory_case_util.h"
#include "oneflow/core/graph/task_graph_rebuild_ctx.h"
Expand Down Expand Up @@ -203,6 +204,38 @@ std::string TaskNode::VisualStr() const {

bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }

void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) {
// Step1: init some scalar items.
CHECK(task_proto.task_type() == GetTaskType());
machine_id_ = task_proto.machine_id();
thrd_id_ = task_proto.thrd_id();
task_id_ = task_proto.task_id();
new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_)));
CHECK(task_proto.job_id() == GlobalJobDesc().job_id());
chain_id_ = task_proto.chain_id();
order_in_chain_ = task_proto.order_in_chain();
// Step2: check exec_gph empty.
CHECK(task_proto.exec_sequence().exec_node().empty());
// Step3: init produced_regst.
for (const auto& pair : task_proto.produced_regst_desc()) {
const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem());
// regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto.
regst_desc->InitFromProtoExceptConsumers(pair.second);
}
}

Maybe<void> TaskNode::InitConsumedRegstsFromProto(
const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) {
// init consumed_regst.
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
for (int64_t regst_desc_id : pair.second.regst_desc_id()) {
ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id)));
}
}
return Maybe<void>::Ok();
}

void TaskNode::ToProto(TaskProto* task_proto, bool check) const {
// Step1: process some scalar items.
task_proto->set_task_type(GetTaskType());
Expand Down Expand Up @@ -265,16 +298,34 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) {
}

void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
if (edge->HasRegst(name)) { return; }
edge->AddRegst(name, GetProducedRegst(name));
}

std::shared_ptr<RegstDesc> TaskNode::GetAndCheckRegst(const std::string& name,
bool enable_reuse_mem,
int32_t min_register_num,
int32_t max_register_num) const {
auto iter = produced_regsts_.find(name);
if (iter == produced_regsts_.end()) { return nullptr; }
const auto& regst = (iter->second);
CHECK_EQ(regst->min_register_num(), min_register_num);
CHECK_EQ(regst->max_register_num(), max_register_num);
CHECK_EQ(regst->enable_reuse_mem(), enable_reuse_mem);
return regst;
}

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) {
return ProduceRegst(name, enable_reuse_mem, 1, kMaxRegisterNum);
}

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,
int32_t min_register_num,
int32_t max_register_num) {
// Because the Regst of separate compilation is not created in order, some Regst may have been
// built. This implementation can avoid ProduceRegst being called multiple times.
const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num);
if (regst) { return regst; }
RegstDescTypeProto regst_desc_type;
regst_desc_type.mutable_data_regst_desc();
return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type);
Expand Down Expand Up @@ -353,8 +404,14 @@ std::shared_ptr<RegstDesc> TaskEdge::GetRegst(const std::string& name_in_produce
return name_in_producer2regst_.at(name_in_producer);
}

bool TaskEdge::HasRegst(const std::string& name_in_producer) const {
return (name_in_producer2regst_.find(name_in_producer) != name_in_producer2regst_.end());
}

std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {
CHECK_EQ(name_in_producer2regst_.size(), 1);
CHECK_EQ(name_in_producer2regst_.size(), 1)
<< "edge: " << this << ", src: " << src_node()->task_id()
<< ", dst: " << dst_node()->task_id();
return name_in_producer2regst_.begin()->second;
}

Expand All @@ -367,6 +424,11 @@ std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {

void TaskEdge::AddRegst(const std::string& name_in_producer,
const std::shared_ptr<RegstDesc>& regst) {
if (HasRegst(name_in_producer)) {
CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id()
== regst->regst_desc_id());
return;
}
CHECK(name_in_producer2regst_.emplace(name_in_producer, regst).second);
}

Expand Down
10 changes: 10 additions & 0 deletions oneflow/core/graph/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::string VisualStr() const override;
virtual bool IsMeaningLess();
void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); }
// Used to create task node from proto in plan separation compilation.
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto);
Maybe<void> InitConsumedRegstsFromProto(
const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id);
virtual void ToProto(TaskProto* task_proto, bool check) const;
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual MemZoneId MemZoneId121() const;
Expand Down Expand Up @@ -150,6 +155,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> {

private:
void UpdateTaskId();
std::shared_ptr<RegstDesc> GetAndCheckRegst(const std::string& name, bool enable_reuse_mem,
int32_t min_register_num,
int32_t max_register_num) const;

int64_t machine_id_;
int64_t thrd_id_;
Expand All @@ -172,6 +180,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
~TaskEdge() override = default;

std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const;
bool HasRegst(const std::string& name_in_producer) const;
std::shared_ptr<RegstDesc> GetSoleRegst() const;
std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const;
const HashSet<LogicalBlobId>& GetLbis() const { return lbis_; }
Expand All @@ -181,6 +190,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }

void CheckRegstLbiValid() const;
bool HasRegst() const { return !name_in_producer2regst_.empty(); }

Maybe<void> InitFromProto(const TaskEdgeProto& proto,
const TaskGraphRebuildCtx& task_graph_rebuild_ctx);
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/graph_impl/acc_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AccCompTaskNode final : public CompTaskNode {
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
};

void AccCompTaskNode::ProduceAllRegstsAndBindEdges() {
Expand All @@ -35,6 +36,7 @@ void AccCompTaskNode::ProduceAllRegstsAndBindEdges() {
}

void AccCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }
void AccCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void AccCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AccCtrlTickCompTaskNode final : public CompTaskNode {
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void BuildExecGphAndRegst() override;
void ConsumeFakeRegsts() override;
};

void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() {
Expand All @@ -38,6 +39,8 @@ void AccCtrlTickCompTaskNode::ConsumeAllRegsts() {
ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst());
}

void AccCtrlTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/graph_impl/acc_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class AccTickCompTaskNode final : public CompTaskNode {
TaskType GetTaskType() const override { return TaskType::kAccTick; }
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -37,6 +38,7 @@ void AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() {
void AccTickCompTaskNode::ConsumeAllRegsts() {
ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst());
}
void AccTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void AccTickCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/callback_notify_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class CallbackNotifyCompTaskNode final : public CompTaskNode {
private:
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -38,6 +39,8 @@ void CallbackNotifyCompTaskNode::ConsumeAllRegsts() {
ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); });
}

void CallbackNotifyCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void CallbackNotifyCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = this->op();
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/case_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CaseCompTaskNode final : public CompTaskNode {

void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;

TaskType GetTaskType() const override { return TaskType::kCase; }

Expand All @@ -36,6 +37,8 @@ class CaseCompTaskNode final : public CompTaskNode {

void CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }

void CaseCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() {
HashMap<LogicalBlobId, int64_t> lbi2obn_id;
FOR_RANGE(int64_t, obn_id, 0, op()->output_bns().size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode {
private:
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -43,6 +44,8 @@ void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() {
ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); });
}

void CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() { ConsumeFakeRegst("in"); }

void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = op();
Expand Down
Loading

0 comments on commit 9cd5d64

Please sign in to comment.