cherry pick save/load in the_one_ps (#37461)
* save/load in ps runtime(the_one_ps) (#36097)

* add trainer desc config to distributed strategy

* code style modified

* data_feed set lod

* fix bug

* code style

* fix bug

* save load

* save load

* save unittest

* add unittest of the_one_ps

* unittest

* add todo in communicator sendsparse

* fix bug in save_inference_model (#37362)
esythan authored Nov 23, 2021
1 parent d5e73f0 commit 58a5113
Showing 7 changed files with 116 additions and 12 deletions.
23 changes: 23 additions & 0 deletions paddle/fluid/distributed/service/communicator.cc
@@ -283,6 +283,18 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
}

// TODO(wangguanqun): padding_idx is not ignored, this is a bug.
// if padding_idx == padding in datareader, the server will core.
/*
for (size_t i = 0; i < tensor->rows().size(); ++i) {
uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
if (real_id != 0) {
sparse_push_keys.push_back(real_id);
push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
}
}
*/

++_async_call_num;
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
@@ -353,6 +365,17 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
return;
}

void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
for (auto &iter : recv_varname_to_ctx) {
auto &table_id = iter.first;
auto &varnames = iter.second;
RpcRecvDense(varnames, table_id, recv_scope_);
VLOG(1) << "pull dense param to table " << table_id
<< " from 0' trainer done";
}
return;
}

void Communicator::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/service/communicator.h
@@ -271,6 +271,8 @@ class Communicator {

virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

virtual void Start() = 0;

virtual void Stop() = 0;
17 changes: 12 additions & 5 deletions paddle/fluid/distributed/table/common_sparse_table.cc
@@ -279,18 +279,25 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
return 0;
}

int32_t CommonSparseTable::load(const std::string& path,
int32_t CommonSparseTable::load(const std::string& dirname,
const std::string& param) {
auto begin = GetCurrentUS();
rwlock_->WRLock();
LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
auto varname = _config.common().table_name();
std::string var_store =
string::Sprintf("%s/%s%s", dirname, varname, PSERVER_SAVE_SUFFIX);
std::string shard_var_pre =
string::Sprintf("%s.block%d", varname, _shard_idx);
std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);

LoadFromText(value_, meta_, _shard_idx, _shard_num, task_pool_size_,
&shard_values_);
rwlock_->UNLock();
auto end = GetCurrentUS();

auto varname = _config.common().table_name();
VLOG(0) << "load " << varname << " with value: " << path
<< " , meta: " << param
VLOG(0) << "load " << varname << " with value: " << value_
<< " , meta: " << meta_
<< " using: " << std::to_string((end - begin) / 1e+6) << " seconds";

return 0;
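The load path now takes a directory instead of explicit value/meta file paths and derives the per-shard file names itself. A minimal Python sketch of the layout it expects, mirroring the string::Sprintf calls above (the helper name is illustrative, and PSERVER_SAVE_SUFFIX is a constant defined elsewhere in the codebase whose exact value is not shown in this diff):

import os

def sparse_shard_paths(dirname, table_name, shard_idx, pserver_save_suffix):
    # <dirname>/<table_name><SUFFIX>/<table_name>.block<shard_idx>.txt  -> values
    # <dirname>/<table_name><SUFFIX>/<table_name>.block<shard_idx>.meta -> meta
    var_store = os.path.join(dirname, table_name + pserver_save_suffix)
    shard_prefix = "{}.block{}".format(table_name, shard_idx)
    return (os.path.join(var_store, shard_prefix + ".txt"),
            os.path.join(var_store, shard_prefix + ".meta"))

This is also why _init_server in the_one_ps.py below can pass just the directory to load_sparse: the table reconstructs the same paths on its side.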
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
@@ -158,7 +158,8 @@ void BindDistCommunicator(py::module* m) {
.def("start", &Communicator::Start)
.def("push_sparse_param", &Communicator::RpcSendSparseParam)
.def("is_running", &Communicator::IsRunning)
.def("init_params", &Communicator::InitParams);
.def("init_params", &Communicator::InitParams)
.def("pull_dense", &Communicator::PullDense);
// .def("recv", &Communicator::RecvNoBarrier);
}

72 changes: 66 additions & 6 deletions python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -868,11 +868,11 @@ def _init_server(self, dirname=None, var_names=None, **kwargs):

for var_name in load_varnames:
table_id = sparse_table_maps[var_name]
path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
"{}.block{}.txt".format(var_name, pserver_id))
meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
"{}.block{}.meta".format(var_name, pserver_id))
self._server.load_sparse(path, meta, table_id)
# path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
# "{}.block{}.txt".format(var_name, pserver_id))
# meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
# "{}.block{}.meta".format(var_name, pserver_id))
self._server.load_sparse(dirname, "0", table_id)

def _run_server(self):
if self.role_maker._is_heter_worker():
@@ -967,8 +967,12 @@ def _save_distributed_persistables(self,
TheOnePSRuntime.__exclude_vars(saved_varnames),
main_program.list_vars()))

self._communicator.pull_dense(denses)

import paddle
for var in remaining_vars:
# if var.name not in recv_dense_varnames:
# continue
tensor = var.get_value()
paddle.save(
tensor, os.path.join(dirname, var.name), use_binary_format=True)
@@ -1063,8 +1067,64 @@ def _save_inference_model(self, *args, **kwargs):
def _save_persistables(self, *args, **kwargs):
self._ps_inference_save_persistables(*args, **kwargs)

def _load_sparse_params(self, dirname, context, main_program, mode):
from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
distributed_varnames = get_sparse_tablenames(
self.compiled_strategy.origin_main_program, True)
values = []
for id, names in context.items():
if names[0] not in distributed_varnames:
# TODO: only load sparse param from local
warnings.warn("varname is not in distributed_varnames, pass")
# load sparse & distributed param on server
self._worker.load_one_table(id, dirname, mode)
values.extend(names)
return values

def _load_distributed_persistables(self, dirname, main_program=None,
mode=0):
if main_program is None:
main_program = self.compiled_strategy.get_origin_ps_main_program()

if isinstance(main_program, CompiledProgram):
raise TypeError(
"in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
)

denses = self.compiled_strategy.get_the_one_recv_context(
is_dense=True,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
sparses = self.compiled_strategy.get_the_one_recv_context(
is_dense=False,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)

sparse_varnames = self._load_sparse_params(dirname, sparses,
main_program, mode)

recv_dense_varnames = []
for id, names in denses.items():
recv_dense_varnames.extend(names)

loaded_varnames = sparse_varnames

remaining_vars = list(
filter(
TheOnePSRuntime.__exclude_vars(loaded_varnames),
main_program.list_vars()))

import paddle
for var in remaining_vars:
if var.name not in recv_dense_varnames:
continue
tensor = paddle.load(os.path.join(dirname, var.name))
var.set_value(tensor)

self._communicator.init_params(denses)

def load_model(self, path, mode):
self._worker.load_model(path, mode)
self._load_distributed_persistables(path, mode=mode)

def _shrink(self, threshold):
import paddle.distributed.fleet as fleet
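Taken together, these runtime changes are exercised end to end in test_fleet_base_2.py further down. A condensed sketch of the user-facing flow (parameter-server setup, the network, and the distributed optimizer are elided; the /tmp dirname and mode=0 simply mirror the test):

import paddle.distributed.fleet as fleet

# ... fleet.init(role), model definition and
# fleet.distributed_optimizer(...).minimize(...) omitted ...
fleet.init_worker()

# Save: _save_distributed_persistables first refreshes dense vars from the
# servers via communicator.pull_dense, then writes them with paddle.save;
# sparse tables are saved on the server side.
fleet.fleet.save(dirname="/tmp")

# Load: load_model now also calls _load_distributed_persistables, which loads
# sparse tables via load_one_table, restores dense vars with paddle.load /
# set_value, and finally calls init_params on the dense context.
fleet.load_model(path="/tmp", mode=0)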
3 changes: 3 additions & 0 deletions python/paddle/fluid/communicator.py
@@ -161,6 +161,9 @@ def recv(self):
def init_params(self, context):
self.communicator_.init_params(context)

def pull_dense(self, context):
self.communicator_.pull_dense(context)

def push_sparse_param(self, var_name, table_id=-1, scope=global_scope()):
if not self.is_running():
raise ValueError(
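pull_dense mirrors init_params on the Python wrapper and forwards to the Communicator::PullDense method bound above. A rough sketch of the calling convention (the context shape is inferred from the `for id, names in denses.items()` loops in the_one_ps.py; the table id and variable names are hypothetical):

# dense recv context: a dict-like mapping of dense table id -> variable names
dense_ctx = {0: ["fc_1.w_0", "fc_1.b_0"]}  # hypothetical id / names

# server -> worker: refresh local dense vars, used before saving them to disk
# in _save_distributed_persistables (wraps RpcRecvDense per table).
communicator.pull_dense(dense_ctx)

# used after restoring dense vars from disk in _load_distributed_persistables.
communicator.init_params(dense_ctx)

Here communicator is the paddle.fluid.communicator.Communicator instance that the_one_ps runtime already holds as self._communicator.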
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fleet_base_2.py
@@ -36,8 +36,13 @@ def test_ps_minimize(self):

input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_slot = paddle.fluid.layers.data(
name="slot", shape=[1], dtype='int64')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

emb = paddle.fluid.layers.embedding(
input=input_slot, size=[10, 9], is_sparse=True)
input_x = paddle.concat(x=[input_x, emb], axis=1)
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
@@ -63,11 +68,14 @@ def test_ps_minimize(self):
compiled_prog = fluid.compiler.CompiledProgram(
fluid.default_main_program())

fleet.init_worker()
fleet.fleet.save(dirname="/tmp", feed=['x', 'y'], fetch=[avg_cost])
fleet.fleet.save(
dirname="/tmp", feed=[input_x, input_y], fetch=[avg_cost])
fleet.fleet.save(dirname="/tmp")

fleet.load_model(path="/tmp", mode=0)

self.assertRaises(
Exception,
fleet.save_inference_model,
