Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add consume fake regst #10140

Merged
merged 25 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions oneflow/core/graph/compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,57 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge(
return comp_task_nodes;
}

std::shared_ptr<RegstDesc> NewFakeDataRegstDesc() {
auto regst_desc = std::make_shared<RegstDesc>();
regst_desc->mut_regst_desc_type()->mutable_data_regst_desc();
return regst_desc;
}

} // namespace

// Consumes a newly created fake data regst under `regst_name` and records the name in
// fake_consumed_regst_names_ so that EraseFakeRegstsIf() can erase it later.
void CompTaskNode::ConsumeFakeRegst(const std::string& regst_name) {
  const auto placeholder_regst = NewFakeDataRegstDesc();
  ConsumeRegst(regst_name, placeholder_regst);
  fake_consumed_regst_names_.insert(regst_name);
}

void CompTaskNode::ConsumeFakeRegstsIf() {
ConsumeFakeRegsts();
RegstDesc* data_regst_desc = nullptr;
for (const auto& pair : consumed_regsts()) {
for (const auto& regst_desc : pair.second) {
if (regst_desc->regst_desc_type().has_data_regst_desc()) {
// Only one fake data regst is created for each CompTaskNode with ConsumeFakeRegsts().
CHECK(data_regst_desc == nullptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没看懂,表示最多只有一个 data regst ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里没看懂,表示最多只有一个 data regst ?

嗯,只创建了一个占位的 fake regst。

void NormalForwardCompTaskNode::ConsumeAllRegsts() {
  ForEachInDataEdge([&](TaskEdge* edge) {
    for (const auto& regst : edge->GetRegsts()) { ConsumeRegst("in", regst); }
  });
}

// Registers a single fake regst under the consumed name "in".
void NormalForwardCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那加下注释吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那加下注释吧

done

data_regst_desc = CHECK_NOTNULL(regst_desc.get());
} else if (regst_desc->regst_desc_type().has_ctrl_regst_desc()) {
// do nothing.
} else {
UNIMPLEMENTED();
}
}
}
if (data_regst_desc != nullptr) {
for (const auto& ibn : op_node()->op().input_bns()) {
// Only one fake data regst is created and just use it for all input_bns as a placeholder.
data_regst_desc->AddLbi(op_node()->op().BnInOp2Lbi(ibn));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多个 lbi 指向了这占位的 fake regst

}
}
}

// Erases the consumed regsts for every name recorded by ConsumeFakeRegst(), then clears the
// record of fake consumed names.
void CompTaskNode::EraseFakeRegstsIf() {
  for (auto name_it = fake_consumed_regst_names_.begin();
       name_it != fake_consumed_regst_names_.end(); ++name_it) {
    EraseConsumedRegstsByName(*name_it);
  }
  fake_consumed_regst_names_.clear();
}

// Returns the op name of the op held by this node's op_node_.
std::string CompTaskNode::VisualStr() const {
  const auto& node_op = op_node_->op();
  return node_op.op_name();
}

// Runs the TaskNode-level initialization from `proto`, then copies the parallel context
// stored in the proto into this node.
void CompTaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& proto) {
  TaskNode::InitFromProtoExceptConsumedRegsts(proto);
  const auto& proto_parallel_ctx = proto.parallel_ctx();
  parallel_ctx_ = proto_parallel_ctx;
}

void CompTaskNode::ToProto(TaskProto* task_proto, bool check) const {
TaskNode::ToProto(task_proto, check);
*(task_proto->mutable_parallel_ctx()) = parallel_ctx_;
Expand Down
11 changes: 10 additions & 1 deletion oneflow/core/graph/compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@ limitations under the License.

#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/graph/fake_consumed_regst_provider.h"
#include "oneflow/core/device/cuda_util.h"

namespace oneflow {

class CompTaskNode : public TaskNode {
class CompTaskNode : public TaskNode, public FakeConsumedRegstProvider {
public:
OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);
CompTaskNode() = default;
virtual ~CompTaskNode() = default;

virtual void ToProto(TaskProto*, bool check) const override;
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto&) override;
void ConsumeFakeRegstsIf() override;
void EraseFakeRegstsIf() override;

// ConsumeFakeRegsts is used for initializing CompTaskNode.consumed_regsts_ on the other ranks.
virtual void ConsumeFakeRegsts() = 0;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BoxingTaskGraph 中的 TransportTaskNode 会带有上游的多个 rank 的 ComputeTaskNode,当做 task graph 的 rank 内 infer BlobDesc 时,非本 rank 的 input regst 是不存在的,但是需要一个 fake 数据来保证正常的推理逻辑可以通过。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本质上这个是一个取巧的临时绕过去的办法。 未来要重构出 task graph build 不依赖 input regst 的版本

void ConsumeFakeRegst(const std::string& regst_name);

// parallel_ctx_
int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
Expand Down Expand Up @@ -58,6 +66,7 @@ class CompTaskNode : public TaskNode {
private:
ParallelContext parallel_ctx_;
const OpNode* op_node_;
HashSet<std::string> fake_consumed_regst_names_;
};

class OpCompTaskNodeCreator {
Expand Down
35 changes: 35 additions & 0 deletions oneflow/core/graph/fake_consumed_regst_provider.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_
#define ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_

namespace oneflow {

// Provide a compute task node with a fake input regst, and its output regst can be inferred using
// SBP + Placement. The fake compute task node can help the task graph of one rank to infer blob
// desc, mainly to ensure that the transport task node has the correct input blob desc.
class FakeConsumedRegstProvider {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixin 接口类

public:
FakeConsumedRegstProvider() = default;
virtual ~FakeConsumedRegstProvider() = default;

virtual void ConsumeFakeRegstsIf() = 0;
virtual void EraseFakeRegstsIf() = 0;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_FAKE_CONSUMED_REGST_PROVIDER_H_
1 change: 1 addition & 0 deletions oneflow/core/graph/normal_forward_compute_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class NormalForwardCompTaskNode final : public CompTaskNode {

void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;

TaskType GetTaskType() const override { return TaskType::kNormalForward; }

Expand Down
64 changes: 63 additions & 1 deletion oneflow/core/graph/task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/memory/memory_case_util.h"
#include "oneflow/core/graph/task_graph_rebuild_ctx.h"
Expand Down Expand Up @@ -203,6 +204,38 @@ std::string TaskNode::VisualStr() const {

bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }

void TaskNode::InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto) {
// Step1: init some scalar items.
CHECK(task_proto.task_type() == GetTaskType());
machine_id_ = task_proto.machine_id();
thrd_id_ = task_proto.thrd_id();
task_id_ = task_proto.task_id();
new_task_id_.reset(new TaskId(DecodeTaskIdFromInt64(task_id_)));
CHECK(task_proto.job_id() == GlobalJobDesc().job_id());
chain_id_ = task_proto.chain_id();
order_in_chain_ = task_proto.order_in_chain();
// Step2: check exec_gph empty.
CHECK(task_proto.exec_sequence().exec_node().empty());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consumed_regst_desc_id 要检查吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consumed_regst_desc_id 要检查吗?

检查它非空么,consumed_regst_desc_id 在后面做了消费

// Step3: init produced_regst.
for (const auto& pair : task_proto.produced_regst_desc()) {
const auto& regst_desc = ProduceRegst(pair.first, pair.second.enable_reuse_mem());
// regst_desc->consumers_ will be initialized by RegstDesc::InitConsumersFromProto.
regst_desc->InitFromProtoExceptConsumers(pair.second);
}
}

Maybe<void> TaskNode::InitConsumedRegstsFromProto(
const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id) {
// init consumed_regst.
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consumed_regst_desc_id 这里做了消费

for (int64_t regst_desc_id : pair.second.regst_desc_id()) {
ConsumeRegst(pair.first, JUST(RegstDesc4Id(regst_desc_id)));
}
}
return Maybe<void>::Ok();
}

void TaskNode::ToProto(TaskProto* task_proto, bool check) const {
// Step1: process some scalar items.
task_proto->set_task_type(GetTaskType());
Expand Down Expand Up @@ -265,16 +298,34 @@ RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name) {
}

void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
if (edge->HasRegst(name)) { return; }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

避免重复绑定

edge->AddRegst(name, GetProducedRegst(name));
}

// Looks up an already produced regst by `name`.
// Returns nullptr when no regst with that name has been produced. Otherwise checks that the
// existing regst has the given register-number bounds and reuse-mem flag, and returns it.
std::shared_ptr<RegstDesc> TaskNode::GetAndCheckRegst(const std::string& name,
                                                      bool enable_reuse_mem,
                                                      int32_t min_register_num,
                                                      int32_t max_register_num) const {
  const auto found = produced_regsts_.find(name);
  if (found != produced_regsts_.end()) {
    const auto& existing_regst = found->second;
    CHECK_EQ(existing_regst->min_register_num(), min_register_num);
    CHECK_EQ(existing_regst->max_register_num(), max_register_num);
    CHECK_EQ(existing_regst->enable_reuse_mem(), enable_reuse_mem);
    return existing_regst;
  }
  return nullptr;
}

// Convenience overload: produces a regst with a minimum register number of 1 and a maximum of
// kMaxRegisterNum.
std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem) {
  constexpr int32_t default_min_register_num = 1;
  return ProduceRegst(name, enable_reuse_mem, default_min_register_num, kMaxRegisterNum);
}

std::shared_ptr<RegstDesc> TaskNode::ProduceRegst(const std::string& name, bool enable_reuse_mem,
int32_t min_register_num,
int32_t max_register_num) {
// Because the Regst of separate compilation is not created in order, some Regst may have been
// built. This implementation can avoid ProduceRegst being called multiple times.
const auto& regst = GetAndCheckRegst(name, enable_reuse_mem, min_register_num, max_register_num);
if (regst) { return regst; }
RegstDescTypeProto regst_desc_type;
regst_desc_type.mutable_data_regst_desc();
return ProduceRegst(name, enable_reuse_mem, min_register_num, max_register_num, regst_desc_type);
Expand Down Expand Up @@ -353,8 +404,14 @@ std::shared_ptr<RegstDesc> TaskEdge::GetRegst(const std::string& name_in_produce
return name_in_producer2regst_.at(name_in_producer);
}

// Reports whether a regst is already bound on this edge under `name_in_producer`.
bool TaskEdge::HasRegst(const std::string& name_in_producer) const {
  return name_in_producer2regst_.count(name_in_producer) > 0;
}

// Returns the only regst bound on this edge.
// Fails a CHECK when the edge does not carry exactly one regst; the failure message includes
// the edge address and the task ids of the source and destination nodes.
std::shared_ptr<RegstDesc> TaskEdge::GetSoleRegst() const {
  // A single check with diagnostics: a preceding bare CHECK_EQ on the same condition would
  // always fire first and hide this message.
  CHECK_EQ(name_in_producer2regst_.size(), 1)
      << "edge: " << this << ", src: " << src_node()->task_id()
      << ", dst: " << dst_node()->task_id();
  return name_in_producer2regst_.begin()->second;
}

Expand All @@ -367,6 +424,11 @@ std::vector<std::shared_ptr<RegstDesc>> TaskEdge::GetRegsts() const {

// Binds `regst` to this edge under `name_in_producer`.
// If that name is already bound, only checks that the bound regst has the same regst_desc_id
// as `regst` and leaves the existing binding untouched.
void TaskEdge::AddRegst(const std::string& name_in_producer,
                        const std::shared_ptr<RegstDesc>& regst) {
  if (!HasRegst(name_in_producer)) {
    const bool inserted = name_in_producer2regst_.emplace(name_in_producer, regst).second;
    CHECK(inserted);
    return;
  }
  // Already bound: repeated binding is tolerated only for the very same regst desc.
  CHECK(CHECK_JUST(MapAt(name_in_producer2regst_, name_in_producer))->regst_desc_id()
        == regst->regst_desc_id());
}

Expand Down
10 changes: 10 additions & 0 deletions oneflow/core/graph/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
std::string VisualStr() const override;
virtual bool IsMeaningLess();
void ToProto(TaskProto* task_proto) const { ToProto(task_proto, /*check*/ true); }
// Used to create task node from proto in plan separation compilation.
virtual void InitFromProtoExceptConsumedRegsts(const TaskProto& task_proto);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以注释一下是为了分离编译中途基于 plan 创建 task node 用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以注释一下是为了分离编译中途基于 plan 创建 task node 用

done

Maybe<void> InitConsumedRegstsFromProto(
const TaskProto& task_proto,
const std::function<Maybe<RegstDesc>(int64_t regst_desc_id)>& RegstDesc4Id);
virtual void ToProto(TaskProto* task_proto, bool check) const;
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual MemZoneId MemZoneId121() const;
Expand Down Expand Up @@ -150,6 +155,9 @@ class TaskNode : public Node<TaskNode, TaskEdge> {

private:
void UpdateTaskId();
std::shared_ptr<RegstDesc> GetAndCheckRegst(const std::string& name, bool enable_reuse_mem,
int32_t min_register_num,
int32_t max_register_num) const;

int64_t machine_id_;
int64_t thrd_id_;
Expand All @@ -172,6 +180,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
~TaskEdge() override = default;

std::shared_ptr<RegstDesc> GetRegst(const std::string& name_in_producer) const;
bool HasRegst(const std::string& name_in_producer) const;
std::shared_ptr<RegstDesc> GetSoleRegst() const;
std::vector<std::shared_ptr<RegstDesc>> GetRegsts() const;
const HashSet<LogicalBlobId>& GetLbis() const { return lbis_; }
Expand All @@ -181,6 +190,7 @@ class TaskEdge final : public Edge<TaskNode, TaskEdge> {
void AddLbis(const std::vector<LogicalBlobId>& lbis) { lbis_.insert(lbis.begin(), lbis.end()); }

void CheckRegstLbiValid() const;
bool HasRegst() const { return !name_in_producer2regst_.empty(); }

Maybe<void> InitFromProto(const TaskEdgeProto& proto,
const TaskGraphRebuildCtx& task_graph_rebuild_ctx);
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/graph_impl/acc_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AccCompTaskNode final : public CompTaskNode {
void BuildExecGphAndRegst() override;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
};

void AccCompTaskNode::ProduceAllRegstsAndBindEdges() {
Expand All @@ -35,6 +36,7 @@ void AccCompTaskNode::ProduceAllRegstsAndBindEdges() {
}

// Consumes, under the name "in", the sole regst of the sole incoming data edge.
void AccCompTaskNode::ConsumeAllRegsts() {
  const auto sole_in_regst = SoleInDataEdge()->GetSoleRegst();
  ConsumeRegst("in", sole_in_regst);
}
// Registers a single fake regst under the consumed name "in".
void AccCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void AccCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AccCtrlTickCompTaskNode final : public CompTaskNode {
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void BuildExecGphAndRegst() override;
void ConsumeFakeRegsts() override;
};

void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() {
Expand All @@ -38,6 +39,8 @@ void AccCtrlTickCompTaskNode::ConsumeAllRegsts() {
ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst());
}

// Registers a single fake regst under the consumed name "in".
void AccCtrlTickCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/graph_impl/acc_tick_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class AccTickCompTaskNode final : public CompTaskNode {
TaskType GetTaskType() const override { return TaskType::kAccTick; }
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -37,6 +38,7 @@ void AccTickCompTaskNode::ProduceAllRegstsAndBindEdges() {
// Consumes, under the name "in", the sole regst of the sole incoming data edge.
void AccTickCompTaskNode::ConsumeAllRegsts() {
  const auto sole_in_regst = SoleInDataEdge()->GetSoleRegst();
  ConsumeRegst("in", sole_in_regst);
}
// Registers a single fake regst under the consumed name "in".
void AccTickCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void AccTickCompTaskNode::BuildExecGphAndRegst() {
std::shared_ptr<RegstDesc> in_regst = GetSoleConsumedRegst("in");
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/callback_notify_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class CallbackNotifyCompTaskNode final : public CompTaskNode {
private:
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -38,6 +39,8 @@ void CallbackNotifyCompTaskNode::ConsumeAllRegsts() {
ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); });
}

// Registers a single fake regst under the consumed name "in".
void CallbackNotifyCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void CallbackNotifyCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = this->op();
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/graph_impl/case_compute_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CaseCompTaskNode final : public CompTaskNode {

void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;

TaskType GetTaskType() const override { return TaskType::kCase; }

Expand All @@ -36,6 +37,8 @@ class CaseCompTaskNode final : public CompTaskNode {

// Consumes, under the name "in", the sole regst of the sole incoming data edge.
void CaseCompTaskNode::ConsumeAllRegsts() {
  const auto sole_in_regst = SoleInDataEdge()->GetSoleRegst();
  ConsumeRegst("in", sole_in_regst);
}

// Registers a single fake regst under the consumed name "in".
void CaseCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() {
HashMap<LogicalBlobId, int64_t> lbi2obn_id;
FOR_RANGE(int64_t, obn_id, 0, op()->output_bns().size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode {
private:
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void ConsumeFakeRegsts() override;
void BuildExecGphAndRegst() override;
};

Expand All @@ -43,6 +44,8 @@ void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() {
ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); });
}

// Registers a single fake regst under the consumed name "in".
void CriticalSectionWaitTickCompTaskNode::ConsumeFakeRegsts() {
  ConsumeFakeRegst("in");
}

void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = op();
Expand Down
Loading