Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add horizontal federation learning ps feature #44327

Merged
merged 70 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
d9bb853
back fl
ziyoujiyi Mar 25, 2022
6073452
delete ssl cert
ziyoujiyi Mar 25, 2022
66fa8c8
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 25, 2022
4bb3d3f
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 25, 2022
7a02e84
.
ziyoujiyi Mar 25, 2022
883b55a
make warning
ziyoujiyi Mar 26, 2022
f917402
.
ziyoujiyi Mar 26, 2022
fa4ab2e
unittest paral degree
ziyoujiyi Mar 28, 2022
a129afc
solve unittest
ziyoujiyi Mar 28, 2022
a54e061
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 29, 2022
ed7e38f
heter & multi cloud commm ready
ziyoujiyi Mar 29, 2022
3e86455
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 29, 2022
b5a34fc
.
ziyoujiyi Mar 29, 2022
0e4b998
Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into de…
ziyoujiyi Mar 29, 2022
eeec283
.
ziyoujiyi Mar 29, 2022
d293d97
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 29, 2022
c1759b5
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 30, 2022
d9aa775
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Mar 31, 2022
7105730
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Apr 2, 2022
73ea318
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Apr 11, 2022
7dc2091
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Apr 19, 2022
2019a5f
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Apr 24, 2022
f22bbcd
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Apr 26, 2022
5019c73
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi May 9, 2022
9b92deb
fl-ps v1.0
ziyoujiyi May 9, 2022
31f330c
merge dev
ziyoujiyi May 9, 2022
f2fa8ee
.
ziyoujiyi May 9, 2022
6c76994
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi May 11, 2022
7aadb99
support N + N mode
ziyoujiyi May 11, 2022
001c11c
Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into fl_ps
ziyoujiyi May 11, 2022
5f7b4fd
.
ziyoujiyi May 11, 2022
a6f7f29
.
ziyoujiyi May 11, 2022
cbbd5e9
.
ziyoujiyi May 12, 2022
2873622
.
ziyoujiyi May 13, 2022
16ad3c1
delete print
ziyoujiyi May 24, 2022
9a89ba3
.
ziyoujiyi May 25, 2022
2469beb
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi May 25, 2022
acc3898
merge dev
ziyoujiyi May 25, 2022
3c5374d
.
ziyoujiyi May 30, 2022
07bf8ab
.
ziyoujiyi May 30, 2022
25f38c1
.
ziyoujiyi May 30, 2022
53aa15c
fix bug
ziyoujiyi Jun 14, 2022
ff90d84
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jun 14, 2022
096d3da
merge dev
ziyoujiyi Jun 14, 2022
29367c9
.
ziyoujiyi Jun 14, 2022
4dc1657
.
ziyoujiyi Jun 15, 2022
cb424b8
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jul 1, 2022
09fe823
fl-ps with coordinator ready
ziyoujiyi Jul 11, 2022
38ef399
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jul 11, 2022
cc44d9b
merge dev
ziyoujiyi Jul 11, 2022
9081f5d
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jul 11, 2022
a1b5e58
Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into fl_ps
ziyoujiyi Jul 11, 2022
1b75c47
merge dev
ziyoujiyi Jul 11, 2022
af4a56a
update message parse only
ziyoujiyi Jul 12, 2022
09f49db
update fl client scheduler
ziyoujiyi Jul 13, 2022
d169c8d
fix bug
ziyoujiyi Jul 14, 2022
d26ed6e
update multithreads sync
ziyoujiyi Jul 14, 2022
4226f19
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jul 14, 2022
baa76c0
Merge branch 'develop' of https://github.com/ziyoujiyi/Paddle into fl_ps
ziyoujiyi Jul 14, 2022
f76ca36
fix ci errors
ziyoujiyi Jul 14, 2022
bb3fd90
update role_maker.py
ziyoujiyi Jul 14, 2022
987079f
update role_maker.py
ziyoujiyi Jul 14, 2022
25459a1
fix ci error: windows py import error
ziyoujiyi Jul 14, 2022
951c284
fix ci error: windows py import error
ziyoujiyi Jul 14, 2022
afe19ca
fix windows ci pylib import error
ziyoujiyi Jul 15, 2022
fd269f8
Merge branch 'PaddlePaddle:develop' into develop
ziyoujiyi Jul 25, 2022
5ba1469
add dump fields & params
ziyoujiyi Jul 25, 2022
b88322a
merge dev
ziyoujiyi Jul 25, 2022
1257de3
try to fix windows import fleet error
ziyoujiyi Jul 25, 2022
6a7f3c9
fix ps FLAGS error
ziyoujiyi Jul 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/external/brpc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
# TODO(gongwb): change to de newst repo when they changed
GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
#GIT_REPOSITORY "https://github.com/ziyoujiyi/brpc" # ssl error in the previous repo(can be mannual fixed)
GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e"
PREFIX ${BRPC_PREFIX_DIR}
UPDATE_COMMAND ""
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/ps/service/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ set_source_files_properties(
graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

set_source_files_properties(
coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(
brpc_utils
SRCS brpc_utils.cc
Expand All @@ -90,6 +94,7 @@ cc_library(
cc_library(
downpour_client
SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
coordinator_client.cc
DEPS eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})

cc_library(
Expand Down
161 changes: 150 additions & 11 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,22 @@
#include <sstream>
#include <string>

#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"

static const int max_port = 65535;

namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace distributed {

DEFINE_int32(pserver_push_dense_merge_limit,
12,
"limit max push_dense local merge requests");
Expand Down Expand Up @@ -66,16 +78,6 @@ DEFINE_int32(pserver_sparse_table_shard_num,
1000,
"sparse table shard for save & load");

namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace distributed {

inline size_t get_sparse_shard(uint32_t shard_num,
uint32_t server_num,
uint64_t key) {
Expand All @@ -101,7 +103,7 @@ void DownpourPsClientService::service(
}
}

// 启动client端RpcService 用于数据互发等操作
// 启动 client 端 RpcService 用于数据互发等操作
int32_t BrpcPsClient::StartClientService() {
if (_service.Configure(this, _client_id) != 0) {
LOG(ERROR)
Expand All @@ -122,6 +124,35 @@ int32_t BrpcPsClient::StartClientService() {
_server_started = true;
_env->RegistePsClient(
butil::my_ip_cstr(), _server.listen_address().port, _client_id);
VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", "
<< _server.listen_address().port << ", " << _client_id;
return 0;
}

// 启动 FlClientService,用户接收 coordinator 数据
int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) {
_fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
if (self_endpoint.empty()) {
LOG(ERROR) << "fl-ps > fl client endpoint not set";
return -1;
}

if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
auto ip_port = paddle::string::Split(self_endpoint, ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= "
<< int_ip_port;
return -1;
}
} else {
VLOG(0) << "fl-ps > StartFlClientService succeed! listen on "
<< self_endpoint;
}
return 0;
}

Expand Down Expand Up @@ -166,6 +197,96 @@ int32_t BrpcPsClient::CreateClient2ClientConnection(
return 0;
}

int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms;
options.connection_type = "pooled";
options.connect_timeout_ms =
paddle::distributed::FLAGS_pserver_connect_timeout_ms;
options.max_retry = 3;
// 获取 coordinator 列表,并连接
std::string coordinator_ip_port;
std::vector<PSHost> coordinator_list = _env->GetCoordinators();
_coordinator_channels.resize(coordinator_list.size());
for (size_t i = 0; i < coordinator_list.size(); ++i) {
coordinator_ip_port.assign(coordinator_list[i].ip.c_str());
coordinator_ip_port.append(":");
coordinator_ip_port.append(std::to_string(coordinator_list[i].port));
VLOG(0) << "fl-ps > BrpcFlclient connetcting to coordinator: "
<< coordinator_ip_port;
for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) {
_coordinator_channels[i][j].reset(new brpc::Channel());
if (_coordinator_channels[i][j]->Init(
coordinator_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< coordinator_ip_port << " Failed! Try again.";
std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip,
coordinator_list[i].port);
if (_coordinator_channels[i][j]->Init(
int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
<< int_ip_port << " Failed!";
return -1;
}
}
}
}
StartFlClientService(self_endpoint);
VLOG(0) << "fl-ps > InitializeFlWorker finished!";
return 0;
}

void BrpcPsClient::PushFLClientInfoSync(const std::string &fl_client_info) {
size_t request_call_num = _coordinator_channels.size();
FlClientBrpcClosure *closure =
new FlClientBrpcClosure(request_call_num, [request_call_num](void *done) {
auto *closure = reinterpret_cast<FlClientBrpcClosure *>(done);
int ret = 0;
for (size_t i = 0; i < request_call_num; i++) {
if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) {
LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from "
"coordinator is failed";
ret = -1;
return;
} else {
VLOG(0) << "fl-ps > rpc service call cost time: "
<< (closure->cntl(i)->latency_us() / 1000) << " ms";
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int32_t> fut = promise->get_future();
closure->add_promise(promise);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->set_str_params(fl_client_info);
brpc::Channel *rpc_channel = _coordinator_channels[0][0].get();
if (rpc_channel == nullptr) {
LOG(ERROR) << "_coordinator_channels is null";
return;
}
PsService_Stub rpc_stub(rpc_channel); // CoordinatorService
rpc_stub.FLService(
closure->cntl(i), closure->request(i), closure->response(i), closure);
fut.wait();
}
VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: " << _client_id;
return;
}

std::string BrpcPsClient::PullFlStrategy() {
while (!_service._is_fl_strategy_ready) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(0) << "fl-ps > waiting for fl strategy returned from coordinator";
}
_service._is_fl_strategy_ready =
false; // only support single thread, no need for multi-threads
return _service._fl_strategy;
}

int32_t BrpcPsClient::Initialize() {
_async_call_num = 0;

Expand Down Expand Up @@ -300,6 +421,24 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
return data;
}

int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}

std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
Expand Down
75 changes: 71 additions & 4 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
Expand Down Expand Up @@ -56,16 +57,71 @@ class DownpourPsClientService : public PsService {
_rank = rank_id;
return 0;
}
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;

virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done);

virtual void FLService(::google::protobuf::RpcController *controller,
const CoordinatorReqMessage *request,
CoordinatorResMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
size_t client_id = request->client_id();
CHECK(_client->_client_id == client_id)
<< "request client id not matched self";
_fl_strategy = request->str_params();
_is_fl_strategy_ready = true;
response->set_err_code(0);
response->set_err_msg("");
VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!";
return;
}

public:
std::string _fl_strategy;
bool _is_fl_strategy_ready = false;

protected:
size_t _rank;
PSClient *_client;
};

class FlClientBrpcClosure : public PSClientClosure {
public:
FlClientBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;

_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~FlClientBrpcClosure() {}
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
CoordinatorReqMessage *request(size_t i) { return &_requests[i]; }
CoordinatorResMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);

private:
std::atomic<int32_t> _waiting_num;
std::vector<CoordinatorReqMessage> _requests;
std::vector<CoordinatorResMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};

class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
Expand Down Expand Up @@ -267,6 +323,14 @@ class BrpcPsClient : public PSClient {
}
int32_t Initialize() override;

// for fl
public:
virtual int32_t InitializeFlWorker(const std::string &self_endpoint);
int32_t StartFlClientService(const std::string &self_endpoint);
virtual void PushFLClientInfoSync(const std::string &fl_client_info);
std::string PullFlStrategy();
// for fl

private:
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
Expand Down Expand Up @@ -320,6 +384,8 @@ class BrpcPsClient : public PSClient {
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
_coordinator_channels; // client2coordinator
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
Expand Down Expand Up @@ -360,6 +426,7 @@ class BrpcPsClient : public PSClient {
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
brpc::Server _fl_server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/distributed/ps/service/communicator/CMakeLists.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
set_source_files_properties(
communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

Expand Down
Loading