Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] Hold the carrier while running for one micro step. #37605

Merged
merged 4 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ void Carrier::Start() {
"Message bus has not been initialized."));
message_bus_instance.Send(tmp_msg);
}
std::unique_lock<std::mutex> lock(running_mutex_);
cond_var_.wait(lock);
}

// Hands out a reference to the carrier's condition variable (cond_var_) so
// that other components can notify the thread that is waiting on it.
std::condition_variable& Carrier::GetCondVar() {
  return cond_var_;
}

// Returns the current value of the is_init_ flag.
bool Carrier::IsInit() const {
  return is_init_;
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
Expand Down Expand Up @@ -57,6 +58,8 @@ class Carrier final {

void SetCreatingFlag(bool flag);

std::condition_variable& GetCondVar();

void Start();

bool IsInit() const;
Expand All @@ -83,6 +86,9 @@ class Carrier final {
bool creating_interceptors_{true};
std::mutex creating_flag_mutex_;
bool is_init_{false};

std::mutex running_mutex_;
std::condition_variable cond_var_;
};

} // namespace distributed
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,11 @@ void ComputeInterceptor::TryStop() {
Send(down_id, stop);
}
stop_ = true;
}

void ComputeInterceptor::HandleStop(const InterceptorMessage& msg) {
ReceivedStop(msg.src_id());

TryStop();
if (out_buffs_.size() == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有好几个汇节点怎么办

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个逻辑估计后面还要改

// TODO(fleet executor dev) need a better place to notify
StopCarrier();
}
}

void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
Expand All @@ -198,6 +197,8 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
} else if (msg.message_type() == DATE_IS_USELESS) {
DecreaseBuff(msg.src_id());
Run();
} else if (msg.message_type() == STOP) {
ReceivedStop(msg.src_id());
}

TryStop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class ComputeInterceptor : public Interceptor {
void Run();
void Compute(const InterceptorMessage& msg);

void HandleStop(const InterceptorMessage& msg) override;
void ReceivedStop(int64_t up_id);
void TryStop();

Expand Down
20 changes: 12 additions & 8 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

Expand Down Expand Up @@ -50,10 +51,20 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
InterceptorMessage msg;
msg.set_message_type(STOP);
Send(interceptor_id_, msg);
} else if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
}
}
}

void Interceptor::StopCarrier() {
Carrier& carrier_instance = Carrier::Instance();
std::condition_variable& cond_var = carrier_instance.GetCondVar();
// probably double notify, but ok for ut
cond_var.notify_all();
}

std::condition_variable& Interceptor::GetCondVar() {
// get the condition variable
return cond_var_;
Expand All @@ -80,9 +91,6 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
return MessageBus::Instance().Send(msg);
}

// Default handling of a STOP message: mark this interceptor as stopped.
// The message argument is not inspected here. A better mechanism may be
// needed for the interceptor base class.
void Interceptor::HandleStop(const InterceptorMessage& msg) {
  stop_ = true;
}

void Interceptor::PoolTheMailbox() {
// poll the local mailbox, parse the Message
for (;;) {
Expand All @@ -101,11 +109,7 @@ void Interceptor::PoolTheMailbox() {
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";

if (message_type == STOP) {
HandleStop(interceptor_message);
} else {
Handle(interceptor_message);
}
Handle(interceptor_message);

if (stop_) {
// break the pooling thread
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ class Interceptor {
// register interceptor handle
void RegisterMsgHandle(MsgHandle handle);

virtual void HandleStop(const InterceptorMessage& msg);

void Handle(const InterceptorMessage& msg);

// return the interceptor id
Expand All @@ -69,6 +67,7 @@ class Interceptor {
protected:
TaskNode* GetTaskNode() const { return node_; }
bool stop_{false};
void StopCarrier();

private:
// poll the local mailbox, parse the Message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class StartInterceptor : public Interceptor {
}

void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
++count_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class PingPongInterceptor : public Interceptor {
}

void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class PingPongInterceptor : public Interceptor {
}

void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
Expand Down
5 changes: 1 addition & 4 deletions python/paddle/fluid/tests/unittests/test_fleet_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ def run_fleet_executor(self, place):
exe.run(empty_program, feed={'x': [1]})

def test_executor_on_single_device(self):
places = [fluid.CPUPlace()]
if fluid.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.run_fleet_executor(place)
self.run_fleet_executor(fluid.CUDAPlace(0))


if __name__ == "__main__":
Expand Down