Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change to new api in async mode #41022

Merged
merged 5 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 45 additions & 37 deletions paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,26 +532,25 @@ std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
if (pull_context.value_type == Dense) { // pull dense
Region *dense_region =
reinterpret_cast<Region *>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
return pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
float **select_values =
reinterpret_cast<float **>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
bool is_training = pull_context.is_training;
if (pull_context.training_mode == Geo) { // for geo
pull_sparse_param(select_values, table_id, keys, num, is_training);
return pull_sparse_param(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
} else if (pull_context.training_mode == Async) { // for async
pull_sparse(select_values, table_id, keys, num, is_training);
return pull_sparse(pull_context.sparse_values, table_id,
pull_context.keys, num, is_training);
}
}
}

std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
if (push_context.value_type == Dense) { // push dense
const Region *dense_region = push_context.push_context.push_dense_values;
push_dense(dense_region, push_context.num, push_context.table);
return push_dense(dense_region, push_context.num, push_context.table);
} else { // push sparse
size_t table_id = push_context.table;
size_t num = push_context.num;
Expand All @@ -561,7 +560,7 @@ std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
} else if (push_context.training_mode == Async) { // for async
const uint64_t *keys = push_context.push_context.keys;
const float **update_values = push_context.push_context.push_values;
push_sparse(table_id, keys, update_values, num);
return push_sparse(table_id, keys, update_values, num);
}
}
}
Expand All @@ -584,11 +583,12 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
sizeof(uint32_t));
keys->resize(shard_nums);
values->resize(shard_nums * accessor->update_dim());
values->resize(shard_nums * accessor->GetTableInfo(UPDATE_DIM));
io_buffer_itr.copy_and_forward((void *)(keys->data()), // NOLINT
sizeof(uint64_t) * shard_nums);
io_buffer_itr.copy_and_forward((void *)(values->data()), // NOLINT
shard_nums * accessor->update_size());
io_buffer_itr.copy_and_forward(
(void *)(values->data()), // NOLINT
shard_nums * accessor->GetTableInfo(UPDATE_SIZE));
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
Expand Down Expand Up @@ -630,21 +630,22 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
auto kvs = ids[shard_idx];
auto value_ptr = value_ptrs[shard_idx];
size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
// 发送RPC请求
auto *push_request = closure->request(shard_idx);
push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
push_request->set_table_id(table_id);
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);
for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size());
push_data_ptr += accessor->update_size();
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
Expand All @@ -660,9 +661,11 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id);
auto fea_dim = accessor->GetTableInfo(FEA_DIM);
auto select_size = accessor->GetTableInfo(SELECT_SIZE);
size_t request_call_num = _server_channels.size();
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
// callback 将各shard结果,顺序填入region
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, num_per_shard, regions, region_num,
Expand All @@ -671,7 +674,8 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_idx = 0; // 当前填充的region偏移
size_t region_data_idx = 0; // 当前填充的region内data偏移
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
size_t shard_data_size = num_per_shard * accessor->select_size();
size_t shard_data_size =
num_per_shard * accessor->GetTableInfo(SELECT_SIZE);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
ret = -1;
Expand Down Expand Up @@ -739,8 +743,8 @@ std::future<int32_t> BrpcPsClient::push_dense_param(const Region *regions,
// 1.拆分Region数据到shard中,后续多shard并行拷贝数据
std::vector<std::vector<Region>> regions_partition(request_call_num);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
size_t shard_data_size = num_per_shard * accessor->update_size();
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
size_t shard_data_size = num_per_shard * accessor->GetTableInfo(UPDATE_SIZE);
size_t current_region_idx = 0;
size_t current_region_data_idx = 0;
for (size_t i = 0; i < request_call_num; ++i) {
Expand Down Expand Up @@ -847,7 +851,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
auto value_ptr = value_ptrs[shard_idx];

size_t kv_size = kvs.size();
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);

// 发送RPC请求
auto *push_request = closure->request(shard_idx);
Expand All @@ -856,14 +860,15 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient(
push_request->set_client_id(_client_id);
push_request->add_params((char *)&kv_size, sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size()));
push_data->resize(kv_size *
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
push_data_ptr += kv_size * sizeof(uint64_t);

for (int i = 0; i < kv_size; ++i) {
memcpy(push_data_ptr, value_ptr[i], accessor->update_size());
push_data_ptr += accessor->update_size();
memcpy(push_data_ptr, value_ptr[i], accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
Expand All @@ -884,7 +889,7 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
std::future<int> fut = promise->get_future();
auto *accessor = table_accessor(table_id);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(table_id);
Expand Down Expand Up @@ -962,7 +967,8 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
}

auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();

size_t value_size = accessor->GetTableInfo(SELECT_SIZE);

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
Expand Down Expand Up @@ -1075,7 +1081,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
}

auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
size_t value_size = accessor->GetTableInfo(SELECT_SIZE);

DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
Expand Down Expand Up @@ -1199,7 +1205,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) {
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->update_size();
size_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
Expand Down Expand Up @@ -1359,8 +1365,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
shard_kv_data.kv_num = 0;
continue;
}

uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
shard_kv_data.value_list[kv_idx].assign(
Expand Down Expand Up @@ -1506,7 +1511,7 @@ void BrpcPsClient::push_sparse_task_consume() {

void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
const float *another_data) {
size_t col_num = accessor->update_size() / sizeof(float);
size_t col_num = accessor->GetTableInfo(UPDATE_SIZE) / sizeof(float);
float *merge_data_shell[col_num];
const float *another_data_shell[col_num];
for (int i = 0; i < col_num; ++i) {
Expand All @@ -1522,7 +1527,7 @@ int BrpcPsClient::push_sparse_async_shard_merge(
ValueAccessor *accessor) {
size_t merged_kv_count = 0;
uint64_t min_key = UINT64_MAX;
uint32_t value_size = accessor->update_size();
uint32_t value_size = accessor->GetTableInfo(UPDATE_SIZE);

thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
sorted_kv_list.clear();
Expand Down Expand Up @@ -1628,8 +1633,9 @@ int BrpcPsClient::push_sparse_async_shard_push(
push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
sizeof(uint32_t)); // NOLINT
auto *push_data = push_request->mutable_data();
int update_size = accessor->GetTableInfo(UPDATE_SIZE);
push_data->resize(merged_kv_count *
(sizeof(uint64_t) + accessor->update_size()));
(sizeof(uint64_t) + accessor->GetTableInfo(UPDATE_SIZE)));
char *push_data_ptr = const_cast<char *>(push_data->data());
memcpy(push_data_ptr, merged_key_list.data(),
merged_kv_count * sizeof(uint64_t));
Expand All @@ -1638,8 +1644,8 @@ int BrpcPsClient::push_sparse_async_shard_push(
const char *task_data_ptr = merged_value_list[i].data();

memcpy(push_data_ptr, (float *)(task_data_ptr), // NOLINT
accessor->update_size());
push_data_ptr += accessor->update_size();
accessor->GetTableInfo(UPDATE_SIZE));
push_data_ptr += accessor->GetTableInfo(UPDATE_SIZE);
}
PsService_Stub rpc_stub(get_sparse_channel(shard_idx));
closure->cntl(shard_idx)->set_request_compress_type(
Expand All @@ -1654,6 +1660,8 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t region_num,
size_t table_id) {
auto *accessor = table_accessor(table_id);
int fea_dim = accessor->GetTableInfo(FEA_DIM);
int update_dim = accessor->GetTableInfo(UPDATE_DIM);
auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
auto parse_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_parse");
Expand All @@ -1673,11 +1681,11 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
size_t request_call_num = _server_channels.size();

uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);

// 将region数据拷贝到转置矩阵中
async_task->data()->resize(num_per_shard * request_call_num *
accessor->update_dim());
accessor->GetTableInfo(UPDATE_DIM));
float *data = async_task->data()->data();
size_t data_size = async_task->data()->size();
uint32_t pos = 0;
Expand Down Expand Up @@ -1806,7 +1814,7 @@ void BrpcPsClient::push_dense_raw_gradient(
auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
closure->add_timer(timer);
uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num);
dense_dim_per_shard(accessor->GetTableInfo(FEA_DIM), request_call_num);
auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) {
Expand Down
37 changes: 30 additions & 7 deletions paddle/fluid/distributed/ps/service/brpc_ps_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,

auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data->data(), num);
TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->pull_dense(res_data->data(), num);

cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
Expand Down Expand Up @@ -264,9 +269,15 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
|--4B---|----------------|
*/
uint32_t num = *(const uint32_t *)(request.data().data());
const float *values =
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values =
(const float *)(request.data().data() + sizeof(uint32_t));
if (table->push_dense(values, num) != 0) {
table_context.num = num;
// const float *values = (const float *)(request.data().data() +
// sizeof(uint32_t));
if (table->Push(table_context) != 0) {
// if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
}

Expand Down Expand Up @@ -388,7 +399,12 @@ int32_t BrpcPsService::pull_sparse(Table *table,

auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * dim);
table->pull_sparse(res_data->data(), value);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.pull_value = value;
table_context.pull_context.values = res_data->data();
table->Pull(table_context);
// table->pull_sparse(res_data->data(), value);

cntl->response_attachment().append((char *)(res_data->data()),
res_data->size() * sizeof(float));
Expand Down Expand Up @@ -421,10 +437,17 @@ int32_t BrpcPsService::push_sparse(Table *table,
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = (const uint64_t *)push_data.data();
table_context.push_context.values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse(keys, values, num) != 0) {
table_context.num = num;
// const uint64_t *keys = (const uint64_t *)push_data.data();
// const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
// num);
if (table->Push(table_context) != 0) {
// if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error");
}
return 0;
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ struct RequestContext {
TrainingMode training_mode; // 1 for async, 2 for geo, 3 for sync
TrainingPhase training_phase; // 1 for init, 2 for train
ValueType value_type; // 1 for sparse, 2 for dense
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
uint64_t *keys;
float **sparse_values; // for sparse values
Region *dense_values; // for dense values
PushContext push_context;
size_t num;
bool is_training;
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ ::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
pull_dense(dense_region, pull_context.num, pull_context.table);
} else { // pull sparse
uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
// uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
// char** select_values =
// reinterpret_cast<char**>(pull_context.sparse_values);
size_t table_id = pull_context.table;
size_t num = pull_context.num;
pull_sparse_ptr(select_values, table_id, keys, num);
pull_sparse_ptr(reinterpret_cast<char**>(pull_context.sparse_values),
table_id, pull_context.keys, num);
}
}

Expand Down
14 changes: 13 additions & 1 deletion paddle/fluid/distributed/ps/table/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ struct AccessorInfo {
size_t fea_dim;
};

enum InfoKey {
DIM = 0,
SIZE = 1,
SELECT_SIZE = 2,
SELECT_DIM = 3,
UPDATE_SIZE = 4,
UPDATE_DIM = 5,
MF_SIZE = 6,
FEA_DIM = 7
};

class ValueAccessor {
public:
ValueAccessor() {}
Expand All @@ -79,7 +90,8 @@ class ValueAccessor {
}
virtual int initialize() = 0;

virtual void GetTableInfo(AccessorInfo& info) = 0;
virtual void SetTableInfo(AccessorInfo& info) = 0;
virtual size_t GetTableInfo(InfoKey key) = 0;

// value维度
virtual size_t dim() = 0;
Expand Down
Loading