Cvm test #14

Merged: 3 commits, Mar 7, 2022
30 changes: 27 additions & 3 deletions paddle/fluid/distributed/ps/wrapper/fleet.cc
@@ -33,6 +33,24 @@ bool FleetWrapper::is_initialized_ = false;
 std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
 std::shared_ptr<paddle::distributed::PSClient> FleetWrapper::worker_ptr_ = NULL;
 
+int FleetWrapper::RegisterHeterCallback(HeterCallBackFunc handler) {
+  VLOG(0) << "RegisterHeterCallback support later";
+  return 0;
+}
+
+int32_t FleetWrapper::CopyTable(const uint64_t src_table_id,
+                                const uint64_t dest_table_id) {
+  VLOG(0) << "CopyTable support later";
+  return 0;
+}
+
+int32_t FleetWrapper::CopyTableByFeasign(
+    const uint64_t src_table_id, const uint64_t dest_table_id,
+    const std::vector<uint64_t>& feasign_list) {
+  VLOG(0) << "CopyTableByFeasign support later";
+  return 0;
+}
+
 void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
                                           int connect_timeout_ms,
                                           int max_retry) {
@@ -166,7 +184,7 @@ void FleetWrapper::StopServer() {
 
 void FleetWrapper::FinalizeWorker() {
   VLOG(3) << "Going to finalize worker";
-  pserver_ptr_->finalize_worker();
+  worker_ptr_->finalize_worker();
 }
 
 void FleetWrapper::BarrierWithTable(uint32_t barrier_type) {
@@ -374,6 +392,7 @@ void FleetWrapper::PullDenseVarsAsync(
   }
   auto status = worker_ptr_->pull_dense(regions.data(), regions.size(), tid);
   pull_dense_status->push_back(std::move(status));
+  VLOG(0) << "debug zcb pscore fleet->PullDenseVarsAsync ret";
 }
 
 void FleetWrapper::PullDenseVarsSync(
@@ -431,7 +450,7 @@ void FleetWrapper::PushDenseVarsAsync(
     float* g = tensor->mutable_data<float>(place);
     paddle::distributed::Region reg(g, tensor->numel());
     regions.emplace_back(std::move(reg));
-    VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
+    VLOG(0) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
            << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
            << g[tensor->numel() - 1];
  }
@@ -741,8 +760,13 @@ void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope,
 }
 
 void FleetWrapper::ClientFlush() {
-  auto ret = pserver_ptr_->_worker_ptr->flush();
+  VLOG(0) << "debug zcb begin client flush";
+  auto ret = worker_ptr_->flush();
   ret.wait();
+  int32_t err_code = ret.get();
+  if (err_code == -1) {
+    LOG(ERROR) << "Client Flush failed";
+  }
 }
 
 int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
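For context: the corrected ClientFlush routes through worker_ptr_, waits on the returned future, and checks the error code instead of discarding it. Below is a minimal standalone sketch of that wait-and-check pattern; std::future is a stand-in for the PS client's actual return type, and FakeFlush is a hypothetical mock, not Paddle API.

#include <cstdint>
#include <future>
#include <iostream>

// Stand-in for the asynchronous flush RPC; the real worker_ptr_->flush()
// returns a future-like object carrying an error code (-1 on failure).
std::future<int32_t> FakeFlush() {
  return std::async(std::launch::async, [] { return int32_t{0}; });
}

int main() {
  auto ret = FakeFlush();
  ret.wait();                    // block until the flush completes
  int32_t err_code = ret.get();  // retrieve the error code
  if (err_code == -1) {
    std::cerr << "Client Flush failed" << std::endl;
  }
  return 0;
}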
10 changes: 10 additions & 0 deletions paddle/fluid/distributed/ps/wrapper/fleet.h
@@ -69,6 +69,16 @@ class FleetWrapper {
     client2client_max_retry_ = 3;
   }
 
+  // TODO(zhaocaibei123): later
+  int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
+
+  int32_t CopyTableByFeasign(const uint64_t src_table_id,
+                             const uint64_t dest_table_id,
+                             const std::vector<uint64_t>& feasign_list);
+
+  typedef std::function<void(int, int)> HeterCallBackFunc;
+  int RegisterHeterCallback(HeterCallBackFunc handler);
+
   // set client to client communication config
   void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
                               int max_retry);
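The header now exposes the (worker, taskid) callback contract via HeterCallBackFunc. A minimal standalone sketch of how a caller might register such a handler follows; RegisterHeterCallback here is a mock of the declaration above (the real method currently just logs "support later" and returns 0), and the invocation inside the mock is purely illustrative.

#include <functional>
#include <iostream>

// Mirrors the typedef added to FleetWrapper above.
typedef std::function<void(int, int)> HeterCallBackFunc;

// Hypothetical stand-in for FleetWrapper::RegisterHeterCallback; it
// invokes the handler once to demonstrate the (worker, taskid) contract.
int RegisterHeterCallback(HeterCallBackFunc handler) {
  handler(/*worker=*/0, /*taskid=*/42);
  return 0;
}

int main() {
  RegisterHeterCallback([](int worker, int taskid) {
    std::cout << "schedule task " << taskid << " on worker " << worker
              << std::endl;
  });
  return 0;
}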
4 changes: 3 additions & 1 deletion paddle/fluid/framework/device_worker.h
@@ -27,6 +27,7 @@ limitations under the License. */
 #include <utility>  // NOLINT
 #include <vector>
 
+#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
 #include "paddle/fluid/framework/data_feed.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
 #include "paddle/fluid/framework/heter_util.h"
@@ -108,6 +109,7 @@ class PullDenseWorker {
 
  private:
   std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
+  std::shared_ptr<paddle::distributed::FleetWrapper> new_fleet_ptr_;
   PullDenseWorkerParameter param_;
   DownpourWorkerParameter dwp_param_;
   Scope* root_scope_;
@@ -350,7 +352,7 @@ class DownpourLiteWorker : public HogwildWorker {
   virtual void TrainFilesWithProfiler();
 
  protected:
-  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
+  std::shared_ptr<paddle::distributed::FleetWrapper> fleet_ptr_;
   std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
   void FillSparseValue(size_t table_id);
   void PushGradients();
16 changes: 13 additions & 3 deletions paddle/fluid/framework/dist_multi_trainer.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
@@ -54,15 +55,19 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
     workers_[i]->SetWorkerNum(thread_num_);
   }
 
-  VLOG(3) << "going to initialize pull dense worker";
+  VLOG(0) << "going to initialize pull dense worker";
   pull_dense_worker_ = PullDenseWorker::GetInstance();
   pull_dense_worker_->Initialize(trainer_desc);
-  VLOG(3) << "initialize pull dense worker";
+  VLOG(0) << "initialize pull dense worker";
   SetDebug(trainer_desc.debug());
 }
 
 void DistMultiTrainer::RegisterHeterCallback() {
+#ifdef PADDLE_WITH_PSLIB
   auto fleet_ptr = FleetWrapper::GetInstance();
+#else
+  auto fleet_ptr = paddle::distributed::FleetWrapper::GetInstance();
+#endif
   fleet_ptr->RegisterHeterCallback(
       [this](int worker, int taskid) { workers_[worker]->Schedule(taskid); });
 }
@@ -176,8 +181,13 @@ void DistMultiTrainer::Finalize() {
   pull_dense_worker_->Stop();
   root_scope_->DropKids();
 
-  // flush local client push queue
-  auto fleet_ptr_ = FleetWrapper::GetInstance();
+  // flush local client push queue
+#ifdef PADDLE_WITH_PSLIB
+  auto fleet_ptr_ = FleetWrapper::GetInstance();
+#else
+  auto fleet_ptr_ = paddle::distributed::FleetWrapper::GetInstance();
+  VLOG(0) << "debug zcb dist multi trainer call client-> flush";
+#endif
   fleet_ptr_->ClientFlush();
 }
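The pattern above selects a fleet backend at compile time: PSLIB builds keep the legacy framework::FleetWrapper singleton, while default builds route the same call sites to paddle::distributed::FleetWrapper. A self-contained sketch of that dispatch, with both wrappers mocked (names and log strings here are illustrative, not Paddle's actual implementations):

#include <iostream>
#include <memory>

// Mock wrappers standing in for paddle::framework::FleetWrapper (PSLIB)
// and paddle::distributed::FleetWrapper (PSCORE).
namespace framework {
struct FleetWrapper {
  static std::shared_ptr<FleetWrapper> GetInstance() {
    static auto inst = std::make_shared<FleetWrapper>();
    return inst;
  }
  void ClientFlush() { std::cout << "flush via PSLIB client" << std::endl; }
};
}  // namespace framework

namespace distributed {
struct FleetWrapper {
  static std::shared_ptr<FleetWrapper> GetInstance() {
    static auto inst = std::make_shared<FleetWrapper>();
    return inst;
  }
  void ClientFlush() { std::cout << "flush via PSCORE client" << std::endl; }
};
}  // namespace distributed

int main() {
#ifdef PADDLE_WITH_PSLIB
  auto fleet_ptr = framework::FleetWrapper::GetInstance();
#else
  auto fleet_ptr = distributed::FleetWrapper::GetInstance();
#endif
  fleet_ptr->ClientFlush();  // same call site; backend picked at compile time
  return 0;
}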