Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] Complete compute interceptor #37485

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 116 additions & 17 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,130 @@ ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)

void ComputeInterceptor::PrepareDeps() {
auto& upstream = GetTaskNode()->upstream();
upstream_deps_.insert(upstream.begin(), upstream.end());
auto& downstream = GetTaskNode()->downstream();

// TODO(wangxi): get from task node
int64_t in_buff_size = std::numeric_limits<int64_t>::max();
int64_t out_buff_size = 2;

for (auto up_id : upstream) {
in_readys_.emplace(up_id, std::make_pair(in_buff_size, 0));
}
for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
}
}

void ComputeInterceptor::IncreaseReady(int64_t up_id) {
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));

auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
PADDLE_ENFORCE_LE(ready_size, max_ready_size,
platform::errors::OutOfRange(
"upstream=%lld ready_size must <= max_ready_size, but "
"now ready_size=%lld, max_ready_size=%lld",
up_id, ready_size, max_ready_size));
it->second.second = ready_size;
}

void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
auto it = out_buffs_.find(down_id);
PADDLE_ENFORCE_NE(it, out_buffs_.end(),
platform::errors::NotFound(
"Cannot find downstream=%lld in out_buffs.", down_id));
auto used_size = it->second.second;
used_size -= 1;
PADDLE_ENFORCE_GE(
used_size, 0,
platform::errors::OutOfRange(
"downstream=%lld used buff size must >= 0, but now equal %lld",
down_id, used_size));
it->second.second = used_size;
}

bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) {
auto ready_size = ins.second.second;
// not ready, return false
if (ready_size == 0) return false;
}
return true;
}

bool ComputeInterceptor::CanWriteOutput() {
for (auto& outs : out_buffs_) {
auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second;
// full, return false
if (used_size == max_buffer_size) return false;
}
return true;
}

void ComputeInterceptor::SendDataReadyToDownStream() {
auto& downstream = GetTaskNode()->downstream();
for (auto dst_id : downstream) {
InterceptorMessage dst_msg;
dst_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send msg to " << dst_id;
Send(dst_id, dst_msg);
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
auto max_buff_size = outs.second.first;
auto used_size = outs.second.second;
used_size += 1;
PADDLE_ENFORCE_LE(
used_size, max_buff_size,
platform::errors::OutOfRange("downstream=%lld used buff size must <= "
"max_buff_size, but now used_size=%lld, "
"max_buff_size=%lld",
down_id, used_size, max_buff_size));
outs.second.second = used_size;

InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
Send(down_id, ready_msg);
}
}

void ComputeInterceptor::ReplyCompletedToUpStream() {
for (auto& ins : in_readys_) {
auto up_id = ins.first;
auto ready_size = ins.second.second;
ready_size -= 1;
PADDLE_ENFORCE_GE(
ready_size, 0,
platform::errors::OutOfRange(
"upstream=%lld ready_size must >= 0, but now got %lld", up_id,
ready_size));
ins.second.second = ready_size;

InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg);
}
}

void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// TODO(wangxi): add op run

// send to downstream and increase buff used
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
}
}

void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
auto src_id = msg.src_id();
upstream_deps_.erase(src_id);

// all input is ready
if (upstream_deps_.empty()) {
// TODO(wangxi): op run
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
SendDataReadyToDownStream();
PrepareDeps();
}
IncreaseReady(msg.src_id());
Run();
} else if (msg.message_type() == DATE_IS_USELESS) {
DecreaseBuff(msg.src_id());
Run();
}
}

Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/distributed/fleet_executor/compute_interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <utility>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
Expand All @@ -25,12 +27,24 @@ class ComputeInterceptor : public Interceptor {

void PrepareDeps();

void IncreaseReady(int64_t up_id);
void DecreaseBuff(int64_t down_id);
bool IsInputReady();
bool CanWriteOutput();

void SendDataReadyToDownStream();
void ReplyCompletedToUpStream();

void Run();
void Compute(const InterceptorMessage& msg);

private:
std::unordered_set<int64_t> upstream_deps_;
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a message RESET, which was designed to reset the step_ to 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那感觉还得加一个start的interceptor,发送开始信息,然后到了指定micro_step后,再发reset信息。

int64_t step_{0};
// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};
};

} // namespace distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,64 @@ class StopInterceptor : public Interceptor {
void Stop(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
count_ += 1;
if (count_ == 1) return;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
Send(3, stop);
}
int count_{0};
};

class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}

void NOP(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};

TEST(ComputeInterceptor, Compute) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}}, {{0, "127.0.0.0:0"}},
"127.0.0.0:0");

Carrier& carrier = Carrier::Instance();

// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 0, 0);

// a->b->c
// a->b->c->d
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
node_c->AddDownstreamTask(3);
node_d->AddUpstreamTask(2);

Interceptor* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
Interceptor* a =
carrier.SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, std::make_unique<StopInterceptor>(2, node_c));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, std::make_unique<StopInterceptor>(3, node_c));

carrier.SetCreatingFlag(false);

InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// double buff, send twice
a->Send(1, msg);
a->Send(1, msg);
}

Expand Down