Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] Add sync method #37167

Merged
merged 20 commits into from
Nov 16, 2021
50 changes: 43 additions & 7 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ bool Carrier::EnqueueInterceptorMessage(
// handle control message
return true;
} else {
if (creating_interceptors_) {
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
Expand Down Expand Up @@ -70,16 +77,45 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
return ptr;
}

// Updates the "interceptors are still being created" state of the carrier.
// Switching the flag off means creation has finished outside the carrier,
// so the messages buffered in the meantime are replayed right away.
void Carrier::SetCreatingFlag(bool flag) {
  // Log the transition before the stored value is overwritten.
  VLOG(3) << "Carrier is set the creating flag from " << creating_interceptors_
          << " to " << flag << ".";
  creating_interceptors_ = flag;
  if (flag) {
    // Still creating: keep buffering, nothing to replay yet.
    return;
  }
  // Interceptors were created outside; handle the temporary messages.
  HandleTmpMessages();
}

void Carrier::HandleTmpMessages() {
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}

// Builds one Interceptor for every (id, TaskNode*) entry of
// interceptor_id_to_node_ and registers it through SetInterceptor.
// NOTE(review): this hunk shows the removed and the added diff lines side by
// side (two loop headers, two VLOG statements) -- read it as a before/after
// view rather than as one compilable body.
void Carrier::CreateInterceptors() {
// create each Interceptor
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
if (!interceptor_id_to_node_.empty()) {
// Auto init runs only when task nodes are configured; with an empty map
// nothing is created here.
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

// TODO(wangxi): use node_type to select different Interceptor
auto interceptor = std::make_unique<Interceptor>(interceptor_id, task_node);
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor for " << interceptor_id;
// TODO(wangxi): use node_type to select different Interceptor
auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node);
SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
}
// Auto init is finished: leave the creating state and replay the messages
// buffered so far. When the map is empty this branch is skipped, so the flag
// is left unchanged and the carrier keeps waiting for an outside initializer
// to call SetCreatingFlag(false).
creating_interceptors_ = false;
HandleTmpMessages();
}
}

Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
Expand Down Expand Up @@ -53,6 +54,8 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);

void SetCreatingFlag(bool flag);

DISABLE_COPY_AND_ASSIGN(Carrier);

private:
Expand All @@ -61,12 +64,17 @@ class Carrier final {
// create each Interceptor
void CreateInterceptors();

void HandleTmpMessages();

// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;

// interceptor logic id to the actual interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;

std::vector<InterceptorMessage> message_tmp_{};
bool creating_interceptors_{true};
};

} // namespace distributed
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
return true;
}

// Stamps msg with this interceptor's id as the source and dst_id as the
// destination, then hands it to the MessageBus singleton.
// NOTE(review): the removed (void) and the added (bool) variants are both
// visible in this hunk; the bool variant returns the result of
// MessageBus::Send.
void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
MessageBus::Instance().Send(msg);
return MessageBus::Instance().Send(msg);
}

void Interceptor::PoolTheMailbox() {
Expand All @@ -76,10 +76,12 @@ void Interceptor::PoolTheMailbox() {
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << interceptor_id_ << " has received a message: " << message_type
<< ".";
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
if (message_type == STOP) {
// break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break;
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Interceptor {
bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);

void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT

DISABLE_COPY_AND_ASSIGN(Interceptor);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from: "
<< request->src_id()
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
response->set_rst(true);
// call interceptor manager's method to handle the message
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
int64_t src_id = interceptor_message.src_id();
int64_t dst_id = interceptor_message.dst_id();
if (IsSameRank(src_id, dst_id)) {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
<< ", which are same ranks.";
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return SendIntraRank(interceptor_message);
} else {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
<< ", which are different ranks.";
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times
Expand Down Expand Up @@ -155,6 +156,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
"Cannot find rank for dst interceptor id %lld. "
"Init error.",
dst_id));
VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
const char* dst_ip_for_brpc = dst_ip->second.c_str();
brpc::Channel channel;
brpc::ChannelOptions options;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ TEST(InterceptorTest, PingPong) {

Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr));

carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
carrier.SetCreatingFlag(false);

InterceptorMessage msg;
a->Send(1, msg);
Expand Down