Skip to content

Commit

Permalink
[GPUPS]FleetWrapper initialize (#44441)
Browse files Browse the repository at this point in the history
* fix FleetWrapper initialize
  • Loading branch information
zmxdream authored Jul 20, 2022
1 parent 0e2dd2f commit 28cb006
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/fleet_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace framework {
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std::shared_ptr<FleetWrapper> FleetWrapper::s_instance_ = NULL;
bool FleetWrapper::is_initialized_ = false;
std::mutex FleetWrapper::ins_mutex;

#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> FleetWrapper::pslib_ptr_ = NULL;
Expand Down
9 changes: 7 additions & 2 deletions paddle/fluid/framework/fleet/fleet_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include <atomic>
#include <ctime>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -381,8 +382,11 @@ class FleetWrapper {
void Revert();
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
{
std::lock_guard<std::mutex> lk(ins_mutex);
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());
}
}
return s_instance_;
}
Expand All @@ -397,6 +401,7 @@ class FleetWrapper {

private:
static std::shared_ptr<FleetWrapper> s_instance_;
static std::mutex ins_mutex;
#ifdef PADDLE_WITH_PSLIB
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/pybind/fleet_wrapper_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindFleetWrapper(py::module* m) {
py::class_<framework::FleetWrapper>(*m, "Fleet")
.def(py::init())
py::class_<framework::FleetWrapper, std::shared_ptr<framework::FleetWrapper>>(
*m, "Fleet")
.def(py::init([]() { return framework::FleetWrapper::GetInstance(); }))
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("pull_dense", &framework::FleetWrapper::PullDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def init(self, role_maker=None):
if role_maker is None:
role_maker = MPISymetricRoleMaker()
super(FleetTranspiler, self).init(role_maker)
self._fleet_ptr = core.Fleet()
if self._fleet_ptr is None:
self._fleet_ptr = core.Fleet()

def _init_transpiler_worker(self):
"""
Expand Down

0 comments on commit 28cb006

Please sign in to comment.